diff --git a/.github/workflows/cbrt-formal.yml b/.github/workflows/cbrt-formal.yml new file mode 100644 index 000000000..d4ded797a --- /dev/null +++ b/.github/workflows/cbrt-formal.yml @@ -0,0 +1,66 @@ +name: Cbrt.sol Formal Check + +on: + push: + branches: + - master + paths: + - src/vendor/Cbrt.sol + - src/wrappers/CbrtWrapper.sol + - formal/cbrt/** + - formal/yul_to_lean.py + - test/0.8.25/Cbrt.t.sol + - test/0.8.25/formal-model/CbrtModel.t.sol + - .github/workflows/cbrt-formal.yml + pull_request: + paths: + - src/vendor/Cbrt.sol + - src/wrappers/CbrtWrapper.sol + - formal/cbrt/** + - formal/yul_to_lean.py + - test/0.8.25/Cbrt.t.sol + - test/0.8.25/formal-model/CbrtModel.t.sol + - .github/workflows/cbrt-formal.yml + +jobs: + cbrt-formal: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + submodules: recursive + + - name: Install Foundry + uses: foundry-rs/foundry-toolchain@v1 + + - name: Install Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install Lean toolchain + run: | + curl https://raw.githubusercontent.com/leanprover/elan/master/elan-init.sh -sSf | sh -s -- -y + echo "$HOME/.elan/bin" >> "$GITHUB_PATH" + + - name: Generate Lean model from Cbrt.sol via Yul IR + run: | + forge inspect src/wrappers/CbrtWrapper.sol:CbrtWrapper ir | \ + python3 -W error formal/cbrt/generate_cbrt_model.py \ + --yul - \ + --output formal/cbrt/CbrtProof/CbrtProof/GeneratedCbrtModel.lean + + - name: Generate finite certificate from cbrt spec + run: | + python3 formal/cbrt/generate_cbrt_cert.py \ + --output formal/cbrt/CbrtProof/CbrtProof/FiniteCert.lean + + - name: Build Cbrt proof and model evaluator + working-directory: formal/cbrt/CbrtProof + run: lake build && lake build cbrt-model + + - name: Fuzz-test Lean model against Solidity + run: | + FOUNDRY_PROFILE=formal-model forge test \ + --skip 'src/*' --skip 'test/unit/*' --skip 'test/integration/*' --skip 'test/0.8.28/*' \ + --match-contract CbrtModelTest diff --git a/.github/workflows/sqrt-formal.yml b/.github/workflows/sqrt-formal.yml new file mode 100644 index 000000000..33171899f --- /dev/null +++ b/.github/workflows/sqrt-formal.yml @@ -0,0 +1,66 @@ +name: Sqrt.sol Formal Check + +on: + push: + branches: + - master + paths: + - src/vendor/Sqrt.sol + - src/wrappers/SqrtWrapper.sol + - formal/sqrt/** + - formal/yul_to_lean.py + - test/0.8.25/Sqrt.t.sol + - test/0.8.25/formal-model/SqrtModel.t.sol + - .github/workflows/sqrt-formal.yml + pull_request: + paths: + - src/vendor/Sqrt.sol + - src/wrappers/SqrtWrapper.sol + - formal/sqrt/** + - formal/yul_to_lean.py + - test/0.8.25/Sqrt.t.sol + - test/0.8.25/formal-model/SqrtModel.t.sol + - .github/workflows/sqrt-formal.yml + +jobs: + sqrt-formal: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + submodules: recursive + + - name: Install Foundry + uses: foundry-rs/foundry-toolchain@v1 + + - name: Install Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install Lean toolchain + run: | + curl https://raw.githubusercontent.com/leanprover/elan/master/elan-init.sh -sSf | sh -s -- -y + echo "$HOME/.elan/bin" >> "$GITHUB_PATH" + + - name: Generate Lean model from Sqrt.sol via Yul IR + run: | + forge inspect src/wrappers/SqrtWrapper.sol:SqrtWrapper ir | \ + python3 -W error formal/sqrt/generate_sqrt_model.py \ + --yul - \ + --output formal/sqrt/SqrtProof/SqrtProof/GeneratedSqrtModel.lean + + - name: Generate finite certificate from sqrt spec + run: | + python3 formal/sqrt/generate_sqrt_cert.py \ + --output formal/sqrt/SqrtProof/SqrtProof/FiniteCert.lean + + - name: Build Sqrt proof and model evaluator + working-directory: formal/sqrt/SqrtProof + run: lake build && lake build sqrt-model + + - name: Fuzz-test Lean model against Solidity + run: | + FOUNDRY_PROFILE=formal-model forge test \ + --skip 'src/*' --skip 'test/unit/*' --skip 'test/integration/*' --skip 'test/0.8.28/*' \ + --match-contract SqrtModelTest diff --git a/.github/workflows/sqrt512-formal.yml b/.github/workflows/sqrt512-formal.yml new file mode 100644 index 000000000..7fd133b94 --- /dev/null +++ b/.github/workflows/sqrt512-formal.yml @@ -0,0 +1,74 @@ +name: 512Math sqrt Formal Check + +on: + push: + branches: + - master + paths: + - src/utils/512Math.sol + - src/vendor/Sqrt.sol + - src/wrappers/Sqrt512Wrapper.sol + - formal/sqrt/** + - formal/yul_to_lean.py + - test/0.8.25/formal-model/Sqrt512Model.t.sol + - .github/workflows/sqrt512-formal.yml + pull_request: + paths: + - src/utils/512Math.sol + - src/vendor/Sqrt.sol + - src/wrappers/Sqrt512Wrapper.sol + - formal/sqrt/** + - formal/yul_to_lean.py + - test/0.8.25/formal-model/Sqrt512Model.t.sol + - .github/workflows/sqrt512-formal.yml + +jobs: + sqrt512-formal: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + submodules: recursive + + - name: Install Foundry + uses: foundry-rs/foundry-toolchain@v1 + + - name: Install Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install Lean toolchain + run: | + curl https://raw.githubusercontent.com/leanprover/elan/master/elan-init.sh -sSf | sh -s -- -y + echo "$HOME/.elan/bin" >> "$GITHUB_PATH" + + - name: Generate 256-bit Lean model (SqrtProof dependency) + run: | + forge inspect src/wrappers/SqrtWrapper.sol:SqrtWrapper ir | \ + python3 -W error formal/sqrt/generate_sqrt_model.py \ + --yul - \ + --output formal/sqrt/SqrtProof/SqrtProof/GeneratedSqrtModel.lean + + - name: Generate finite certificate from sqrt spec + run: | + python3 formal/sqrt/generate_sqrt_cert.py \ + --output formal/sqrt/SqrtProof/SqrtProof/FiniteCert.lean + + - name: Generate 512-bit Lean model from Sqrt512Wrapper via Yul IR + run: | + FOUNDRY_SOLC_VERSION=0.8.33 \ + forge inspect src/wrappers/Sqrt512Wrapper.sol:Sqrt512Wrapper ir | \ + python3 -W error formal/sqrt/generate_sqrt512_model.py \ + --yul - \ + --output formal/sqrt/Sqrt512Proof/Sqrt512Proof/GeneratedSqrt512Model.lean + + - name: Build Sqrt512 proof and model evaluator + working-directory: formal/sqrt/Sqrt512Proof + run: lake build && lake build sqrt512-model + + - name: Fuzz-test Lean model against Solidity + run: | + FOUNDRY_PROFILE=formal-model forge test \ + --skip 'src/*' --skip 'test/unit/*' --skip 'test/integration/*' --skip 'test/0.8.28/*' \ + --match-contract Sqrt512ModelTest diff --git a/.gitignore b/.gitignore index 11c7e59b7..96bb13a3a 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,6 @@ node_modules # user-specific local configuration /config/ + +# Python bytecode cache directories +__pycache__/ diff --git a/formal/README.md b/formal/README.md new file mode 100644 index 000000000..849d3644b --- /dev/null +++ b/formal/README.md @@ -0,0 +1,53 @@ +# Formal Verification + +Machine-checked Lean 4 correctness proofs for root math libraries in 0x Settler. Zero `sorry`, no axioms beyond the Lean kernel. + +## Scope + +| Proof | Solidity source | What is proved | +|-------|----------------|----------------| +| `sqrt/SqrtProof` | `src/vendor/Sqrt.sol` | `_sqrt`, `sqrt`, `sqrtUp` correct on uint256 | +| `sqrt/Sqrt512Proof` | `src/utils/512Math.sol` | `_sqrt` (512-bit) correct: `sqrt(x_hi * 2^256 + x_lo) = natSqrt(x)` | +| `cbrt/CbrtProof` | `src/vendor/Cbrt.sol` | `_cbrt`, `cbrt`, `cbrtUp` correct on uint256 | + +## Method + +1. **Algebraic lemmas** prove one-step safety and correction logic (Babylonian / Newton-Raphson steps). +2. **Finite domain certificates** (auto-generated by Python scripts) cover all uint256 octaves with `by decide` proofs. +3. **Solidity-to-Lean generators** parse Yul IR into EVM-faithful and normalized Lean models. +4. **End-to-end bridge theorems** prove the generated EVM models equal the hand-written mathematical specs. + +## Build + +All auto-generated files (`GeneratedSqrtModel.lean`, `FiniteCert.lean`, etc.) are `.gitignore`d and regenerated in CI. See `.github/workflows/sqrt-formal.yml`, `sqrt512-formal.yml`, and `cbrt-formal.yml` for the canonical build steps. + +```bash +# --- 256-bit sqrt --- +forge inspect src/wrappers/SqrtWrapper.sol:SqrtWrapper ir | \ + python3 formal/sqrt/generate_sqrt_model.py --yul - \ + --output formal/sqrt/SqrtProof/SqrtProof/GeneratedSqrtModel.lean + +python3 formal/sqrt/generate_sqrt_cert.py \ + --output formal/sqrt/SqrtProof/SqrtProof/FiniteCert.lean + +cd formal/sqrt/SqrtProof && lake build + +# --- 512-bit sqrt --- +FOUNDRY_SOLC_VERSION=0.8.33 \ + forge inspect src/wrappers/Sqrt512Wrapper.sol:Sqrt512Wrapper ir | \ + python3 formal/sqrt/generate_sqrt512_model.py --yul - \ + --output formal/sqrt/Sqrt512Proof/Sqrt512Proof/GeneratedSqrt512Model.lean + +# FiniteCert.lean (shared with SqrtProof) must be generated first — see above. +cd formal/sqrt/Sqrt512Proof && lake build + +# --- cbrt --- +forge inspect src/wrappers/CbrtWrapper.sol:CbrtWrapper ir | \ + python3 formal/cbrt/generate_cbrt_model.py --yul - \ + --output formal/cbrt/CbrtProof/CbrtProof/GeneratedCbrtModel.lean + +python3 formal/cbrt/generate_cbrt_cert.py \ + --output formal/cbrt/CbrtProof/CbrtProof/FiniteCert.lean + +cd formal/cbrt/CbrtProof && lake build +``` diff --git a/formal/cbrt/CbrtProof/.gitignore b/formal/cbrt/CbrtProof/.gitignore new file mode 100644 index 000000000..f701d47b9 --- /dev/null +++ b/formal/cbrt/CbrtProof/.gitignore @@ -0,0 +1,8 @@ +/.lake +lake-manifest.json + +# Auto-generated from `formal/cbrt/generate_cbrt_cert.py` +/CbrtProof/FiniteCert.lean + +# Auto-generated from `formal/cbrt/generate_cbrt_model.py` +/CbrtProof/GeneratedCbrtModel.lean diff --git a/formal/cbrt/CbrtProof/CbrtProof.lean b/formal/cbrt/CbrtProof/CbrtProof.lean new file mode 100644 index 000000000..96100a789 --- /dev/null +++ b/formal/cbrt/CbrtProof/CbrtProof.lean @@ -0,0 +1,10 @@ +-- This module serves as the root of the `CbrtProof` library. +-- Import modules here that should be built as part of the library. +import CbrtProof.FloorBound +import CbrtProof.CbrtCorrect +import CbrtProof.FiniteCert +import CbrtProof.CertifiedChain +import CbrtProof.Wiring +import CbrtProof.OverflowSafety +import CbrtProof.GeneratedCbrtModel +import CbrtProof.GeneratedCbrtSpec diff --git a/formal/cbrt/CbrtProof/CbrtProof/CbrtCorrect.lean b/formal/cbrt/CbrtProof/CbrtProof/CbrtCorrect.lean new file mode 100644 index 000000000..d5a6e767b --- /dev/null +++ b/formal/cbrt/CbrtProof/CbrtProof/CbrtCorrect.lean @@ -0,0 +1,747 @@ +/- + Full correctness proof of Cbrt.sol:_cbrt and cbrt. + + This file includes: + 1) A concrete integer cube-root function `icbrt` with formal floor specification. + 2) Explicit (named) correctness theorems for `innerCbrt` and `floorCbrt`, + parameterized by the remaining upper-bound hypothesis + `innerCbrt x ≤ icbrt x + 1`. +-/ +import Init +import CbrtProof.FloorBound + +-- ============================================================================ +-- Part 1: Definitions matching Cbrt.sol EVM semantics +-- ============================================================================ + +/-- One Newton-Raphson step for cube root: ⌊(⌊x/z²⌋ + 2z) / 3⌋. + Matches EVM: div(add(add(div(x, mul(z, z)), z), z), 3) -/ +def cbrtStep (x z : Nat) : Nat := (x / (z * z) + 2 * z) / 3 + +/-- Run five cbrt Newton steps from an explicit starting point. -/ +def run5From (x z : Nat) : Nat := + let z := cbrtStep x z + let z := cbrtStep x z + let z := cbrtStep x z + let z := cbrtStep x z + let z := cbrtStep x z + z + +/-- Run six cbrt Newton steps from an explicit starting point. -/ +def run6From (x z : Nat) : Nat := + let z := cbrtStep x z + let z := cbrtStep x z + let z := cbrtStep x z + let z := cbrtStep x z + let z := cbrtStep x z + let z := cbrtStep x z + z + +/-- run6From = cbrtStep after run5From (definitional). -/ +theorem run6_eq_step_run5 (x z : Nat) : + run6From x z = cbrtStep x (run5From x z) := rfl + +/-- The cbrt seed: + z = ⌊233 * 2^q / 256⌋ + 1 where q = ⌊(log2(x) + 2) / 3⌋. + Matches EVM: add(shr(8, shl(div(sub(257, clz(x)), 3), 0xe9)), lt(0x00, x)) -/ +def cbrtSeed (x : Nat) : Nat := + (0xe9 <<< ((Nat.log2 x + 2) / 3)) >>> 8 + 1 + +/-- _cbrt: seed + 6 Newton-Raphson steps. -/ +def innerCbrt (x : Nat) : Nat := + let z := cbrtSeed x + let z := cbrtStep x z + let z := cbrtStep x z + let z := cbrtStep x z + let z := cbrtStep x z + let z := cbrtStep x z + let z := cbrtStep x z + z + +/-- cbrt: _cbrt with floor correction. + Matches: z := sub(z, lt(div(x, mul(z, z)), z)) -/ +def floorCbrt (x : Nat) : Nat := + let z := innerCbrt x + if x / (z * z) < z then z - 1 else z + +-- ============================================================================ +-- Part 1b: Reference integer cube root (floor) +-- ============================================================================ + +/-- Search helper: largest `m ≤ n` such that `m^3 ≤ x`. -/ +def icbrtAux (x n : Nat) : Nat := + match n with + | 0 => 0 + | n + 1 => if (n + 1) * (n + 1) * (n + 1) ≤ x then n + 1 else icbrtAux x n + +/-- Reference integer cube root (floor). -/ +def icbrt (x : Nat) : Nat := + icbrtAux x x + +theorem cube_monotone {a b : Nat} (h : a ≤ b) : + a * a * a ≤ b * b * b := by + have h1 : a * a * a ≤ b * a * a := by + have hmul : a * a ≤ b * a := Nat.mul_le_mul_right a h + exact Nat.mul_le_mul_right a hmul + have h2 : b * a * a ≤ b * b * a := by + have hmul : b * a ≤ b * b := Nat.mul_le_mul_left b h + exact Nat.mul_le_mul_right a hmul + have h3 : b * b * a ≤ b * b * b := by + exact Nat.mul_le_mul_left (b * b) h + exact Nat.le_trans h1 (Nat.le_trans h2 h3) + +private theorem le_cube_of_pos {a : Nat} (ha : 0 < a) : + a ≤ a * a * a := by + have h1 : 1 ≤ a := Nat.succ_le_of_lt ha + have h2 : a ≤ a * a := by + simpa [Nat.mul_one] using (Nat.mul_le_mul_left a h1) + have h3 : a * a ≤ a * a * a := by + simpa [Nat.mul_one, Nat.mul_assoc] using (Nat.mul_le_mul_left (a * a) h1) + exact Nat.le_trans h2 h3 + +private theorem icbrtAux_cube_le (x n : Nat) : + icbrtAux x n * icbrtAux x n * icbrtAux x n ≤ x := by + induction n with + | zero => simp [icbrtAux] + | succ n ih => + by_cases h : (n + 1) * (n + 1) * (n + 1) ≤ x + · simp [icbrtAux, h] + · simpa [icbrtAux, h] using ih + +private theorem icbrtAux_greatest (x : Nat) : + ∀ n m, m ≤ n → m * m * m ≤ x → m ≤ icbrtAux x n := by + intro n + induction n with + | zero => + intro m hmn hm + have hm0 : m = 0 := by omega + subst hm0 + simp [icbrtAux] + | succ n ih => + intro m hmn hm + by_cases h : (n + 1) * (n + 1) * (n + 1) ≤ x + · simp [icbrtAux, h] + exact hmn + · have hm_le_n : m ≤ n := by + by_cases hm_eq : m = n + 1 + · subst hm_eq + exact False.elim (h hm) + · omega + have hm_le_aux : m ≤ icbrtAux x n := ih m hm_le_n hm + simpa [icbrtAux, h] using hm_le_aux + +/-- Lower half of the floor specification: `icbrt(x)^3 ≤ x`. -/ +theorem icbrt_cube_le (x : Nat) : + icbrt x * icbrt x * icbrt x ≤ x := by + unfold icbrt + exact icbrtAux_cube_le x x + +/-- Upper half of the floor specification: `x < (icbrt(x)+1)^3`. -/ +theorem icbrt_lt_succ_cube (x : Nat) : + x < (icbrt x + 1) * (icbrt x + 1) * (icbrt x + 1) := by + by_cases hlt : x < (icbrt x + 1) * (icbrt x + 1) * (icbrt x + 1) + · exact hlt + · have hle : (icbrt x + 1) * (icbrt x + 1) * (icbrt x + 1) ≤ x := Nat.le_of_not_lt hlt + have hpos : 0 < icbrt x + 1 := by omega + have hmx : icbrt x + 1 ≤ x := by + have hleCube : icbrt x + 1 ≤ (icbrt x + 1) * (icbrt x + 1) * (icbrt x + 1) := + le_cube_of_pos hpos + exact Nat.le_trans hleCube hle + have hmax : icbrt x + 1 ≤ icbrt x := by + unfold icbrt + exact icbrtAux_greatest x x (icbrt x + 1) hmx hle + exact False.elim ((Nat.not_succ_le_self (icbrt x)) hmax) + +/-- Uniqueness: any `r` satisfying the floor specification equals `icbrt(x)`. -/ +theorem icbrt_eq_of_bounds (x r : Nat) + (hlo : r * r * r ≤ x) + (hhi : x < (r + 1) * (r + 1) * (r + 1)) : + r = icbrt x := by + have hrx : r ≤ x := by + by_cases hr0 : r = 0 + · omega + · have hrpos : 0 < r := Nat.pos_of_ne_zero hr0 + have hrle : r ≤ r * r * r := le_cube_of_pos hrpos + exact Nat.le_trans hrle hlo + have h1 : r ≤ icbrt x := by + unfold icbrt + exact icbrtAux_greatest x x r hrx hlo + have h2 : icbrt x ≤ r := by + by_cases hic : icbrt x ≤ r + · exact hic + · have hr1_le : r + 1 ≤ icbrt x := Nat.succ_le_of_lt (Nat.lt_of_not_ge hic) + have hmono : (r + 1) * (r + 1) * (r + 1) ≤ icbrt x * icbrt x * icbrt x := + cube_monotone hr1_le + have hicbrt : icbrt x * icbrt x * icbrt x ≤ x := icbrt_cube_le x + have : (r + 1) * (r + 1) * (r + 1) ≤ x := Nat.le_trans hmono hicbrt + exact False.elim (Nat.not_le_of_lt hhi this) + exact Nat.le_antisymm h1 h2 + +-- ============================================================================ +-- Part 2: Seed and step positivity +-- ============================================================================ + +/-- The cbrt seed is always positive (due to the +1 term). -/ +theorem cbrtSeed_pos (x : Nat) : 0 < cbrtSeed x := by + unfold cbrtSeed + exact Nat.succ_pos _ + +/-- cbrtStep preserves positivity when x > 0 and z > 0. -/ +theorem cbrtStep_pos (x z : Nat) (hx : 0 < x) (hz : 0 < z) : 0 < cbrtStep x z := by + unfold cbrtStep + -- Numerator = x/(z*z) + 2*z ≥ 2*z ≥ 2. + -- For z = 1: numerator = x + 2 ≥ 3, so /3 ≥ 1. + -- For z ≥ 2: numerator ≥ 4, so /3 ≥ 1. + have hzz : 0 < z * z := Nat.mul_pos hz hz + by_cases h : z = 1 + · -- z = 1: numerator = x/1 + 2 = x + 2 ≥ 3, so /3 ≥ 1 + subst h; simp + -- goal: 0 < (x / 1 + 2) / 3 or similar. omega handles. + omega + · -- z ≥ 2: numerator ≥ 0 + 2z ≥ 4, so /3 ≥ 1 + have hz2 : z ≥ 2 := by omega + -- x/(z*z) is a Nat ≥ 0. 2*z ≥ 4. Sum ≥ 4. 4/3 = 1 > 0. + have h_num_ge : x / (z * z) + 2 * z ≥ 3 := by + have : 2 * z ≥ 4 := by omega + have : x / (z * z) ≥ 0 := Nat.zero_le _ + omega + omega + +-- ============================================================================ +-- Part 3: Upper bound machinery (one-step contraction) +-- ============================================================================ + +/-- Integer polynomial identity used to upper-bound one cbrt Newton step. -/ +private theorem pullCoeff (x y c : Int) : x * (y * c) = c * (x * y) := by + rw [← Int.mul_assoc x y c] + rw [Int.mul_comm (x * y) c] + +private theorem pullCoeffNested (x y z c : Int) : x * (y * (z * c)) = c * (x * (y * z)) := by + rw [← Int.mul_assoc y z c] + rw [pullCoeff x (y * z) c] + +private theorem int_poly_identity (m d q r : Int) + (hd2 : d * d = m * q + r) : + ((m - 2 * d + 3 * q + 6) * ((m + d) * (m + d)) - (m + 1) * (m + 1) * (m + 1)) + = + q * (3 * m * q + 6 * m + 3 * r + 4 * d * m) + + (-2 * d * r + 12 * d * m + 3 * m * m - 3 * m * r - 3 * m + 6 * r - 1) := by + simp [Int.sub_eq_add_neg, Int.add_mul, Int.mul_add, + Int.mul_assoc, Int.mul_comm, Int.mul_left_comm] + repeat rw [Int.mul_neg] + repeat rw [Int.neg_mul] + have hddx (x : Int) : d * (d * x) = (d * d) * x := by + rw [← Int.mul_assoc] + simp [hddx, hd2, Int.add_mul, Int.mul_add, + Int.mul_assoc, Int.mul_left_comm] + -- Normalize monomials with numeric coefficients. + rw [pullCoeffNested m m d 2] + rw [pullCoeffNested m m q 2] + rw [pullCoeff m r 2] + rw [pullCoeffNested m d q 2] + rw [pullCoeff d r 2] + rw [pullCoeffNested m m q 3] + rw [pullCoeffNested m d q 3] + rw [pullCoeff m m 6] + rw [pullCoeff m d 6] + rw [pullCoeff m q 6] + rw [pullCoeff m d 12] + rw [pullCoeffNested m d q 4] + rw [pullCoeff m r 3] + rw [pullCoeff m m 3] + -- Collapse the expanded `(m + 1)^3` chunk. + have hcube : + m * (m * m) + m * m + (m * m + m) + (m * m + m + (m + 1)) + = m * (m * m) + 3 * (m * m) + 3 * m + 1 := by + omega + rw [hcube] + omega + +private theorem neg3_mul_mul (m r : Int) : -3 * m * r = -(3 * m * r) := by + calc + -3 * m * r = (-3 * m) * r := by rw [Int.mul_assoc] + _ = (-(3 * m)) * r := by rw [Int.neg_mul] + _ = -(3 * m * r) := by rw [Int.neg_mul, Int.mul_assoc] + +private theorem mul_coeff_expand (m d r : Int) : + r * (-2 * d - 3 * m + 6) = -2 * d * r - 3 * m * r + 6 * r := by + rw [Int.mul_add] + have hsum : -2 * d - 3 * m = (-2 * d) + (-3 * m) := by omega + rw [hsum, Int.mul_add] + rw [Int.mul_comm r (-2 * d), Int.mul_comm r (-3 * m), Int.mul_comm r 6] + repeat rw [Int.sub_eq_add_neg] + rw [neg3_mul_mul] + +/-- Product form of the one-step upper bound (core arithmetic bridge). -/ +private theorem one_step_prod_bound (m d : Nat) (hm2 : 2 ≤ m) : + (m + 1) * (m + 1) * (m + 1) ≤ + (m - 2 * d + 3 * (d * d / m) + 6) * ((m + d) * (m + d)) := by + let q : Nat := d * d / m + let r : Nat := d * d % m + have hm : 0 < m := by omega + have hr : r < m := by + dsimp [r] + exact Nat.mod_lt _ hm + have hd2 : d * d = m * q + r := by + dsimp [q, r] + exact (Nat.div_add_mod (d * d) m).symm + + have hd2i : (d : Int) * (d : Int) = (m : Int) * (q : Int) + (r : Int) := by + exact_mod_cast hd2 + + have hEqInt : + (((m : Int) - 2 * (d : Int) + 3 * (q : Int) + 6) * + (((m : Int) + (d : Int)) * ((m : Int) + (d : Int))) + - ((m : Int) + 1) * ((m : Int) + 1) * ((m : Int) + 1)) + = + (q : Int) * (3 * (m : Int) * (q : Int) + 6 * (m : Int) + 3 * (r : Int) + 4 * (d : Int) * (m : Int)) + + (-2 * (d : Int) * (r : Int) + 12 * (d : Int) * (m : Int) + + 3 * (m : Int) * (m : Int) - 3 * (m : Int) * (r : Int) + - 3 * (m : Int) + 6 * (r : Int) - 1) := by + exact int_poly_identity (m := (m : Int)) (d := (d : Int)) (q := (q : Int)) (r := (r : Int)) hd2i + + have hm_nonneg : 0 ≤ (m : Int) := Int.natCast_nonneg m + have hq_nonneg : 0 ≤ (q : Int) := Int.natCast_nonneg q + have hr_nonneg : 0 ≤ (r : Int) := Int.natCast_nonneg r + have hd_nonneg : 0 ≤ (d : Int) := Int.natCast_nonneg d + + have h3_nonneg : (0 : Int) ≤ 3 := by decide + have h4_nonneg : (0 : Int) ≤ 4 := by decide + have h6_nonneg : (0 : Int) ≤ 6 := by decide + have h10_nonneg : (0 : Int) ≤ 10 := by decide + have h2_nonneg : (0 : Int) ≤ 2 := by decide + + have h3m_nonneg : 0 ≤ 3 * (m : Int) := Int.mul_nonneg h3_nonneg hm_nonneg + have h6m_nonneg : 0 ≤ 6 * (m : Int) := Int.mul_nonneg h6_nonneg hm_nonneg + have h3r_nonneg : 0 ≤ 3 * (r : Int) := Int.mul_nonneg h3_nonneg hr_nonneg + have h4d_nonneg : 0 ≤ 4 * (d : Int) := Int.mul_nonneg h4_nonneg hd_nonneg + + have h1_nonneg : 0 ≤ 3 * (m : Int) * (q : Int) := Int.mul_nonneg h3m_nonneg hq_nonneg + have h4_nonneg' : 0 ≤ 4 * (d : Int) * (m : Int) := Int.mul_nonneg h4d_nonneg hm_nonneg + + have hfac_nonneg : 0 ≤ 3 * (m : Int) * (q : Int) + 6 * (m : Int) + 3 * (r : Int) + 4 * (d : Int) * (m : Int) := by + omega + + have hQ_nonneg : + 0 ≤ (q : Int) * (3 * (m : Int) * (q : Int) + 6 * (m : Int) + 3 * (r : Int) + 4 * (d : Int) * (m : Int)) := by + exact Int.mul_nonneg hq_nonneg hfac_nonneg + + have hc_nonpos : -2 * (d : Int) - 3 * (m : Int) + 6 ≤ 0 := by + have hm_ge_two : (2 : Int) ≤ (m : Int) := by exact_mod_cast hm2 + omega + + have hr_le : (r : Int) ≤ ((m - 1 : Nat) : Int) := by + have : r ≤ m - 1 := by omega + exact Int.ofNat_le.mpr this + + have h_mul_lower : + ((m - 1 : Nat) : Int) * (-2 * (d : Int) - 3 * (m : Int) + 6) + ≤ (r : Int) * (-2 * (d : Int) - 3 * (m : Int) + 6) := by + exact Int.mul_le_mul_of_nonpos_right hr_le hc_nonpos + + have h_rewrite : + (-2 * (d : Int) * (r : Int) + 12 * (d : Int) * (m : Int) + + 3 * (m : Int) * (m : Int) - 3 * (m : Int) * (r : Int) + - 3 * (m : Int) + 6 * (r : Int) - 1) + = (12 * (d : Int) * (m : Int) + 3 * (m : Int) * (m : Int) - 3 * (m : Int) - 1) + + (r : Int) * (-2 * (d : Int) - 3 * (m : Int) + 6) := by + rw [mul_coeff_expand (m := (m : Int)) (d := (d : Int)) (r := (r : Int))] + repeat rw [Int.sub_eq_add_neg] + ac_rfl + + have h_rewrite0 : + (12 * (d : Int) * (m : Int) + 3 * (m : Int) * (m : Int) - 3 * (m : Int) - 1) + + ((m - 1 : Nat) : Int) * (-2 * (d : Int) - 3 * (m : Int) + 6) + = 10 * (d : Int) * (m : Int) + 2 * (d : Int) + 6 * (m : Int) - 7 := by + have hm1 : 1 ≤ m := Nat.le_trans (by decide : 1 ≤ 2) hm2 + have ht : ((m - 1 : Nat) : Int) = (m : Int) - 1 := by omega + rw [ht, Int.sub_mul, Int.one_mul] + rw [mul_coeff_expand (m := (m : Int)) (d := (d : Int)) (r := (m : Int))] + repeat rw [Int.sub_eq_add_neg] + have hneg : -(-2 * (d : Int) + -(3 * (m : Int)) + 6) = 2 * (d : Int) + 3 * (m : Int) - 6 := by + omega + rw [hneg] + rw [Int.mul_assoc 12 (d : Int) (m : Int)] + rw [Int.mul_assoc (-2) (d : Int) (m : Int)] + rw [Int.mul_assoc 10 (d : Int) (m : Int)] + rw [Int.mul_assoc 3 (m : Int) (m : Int)] + omega + + have h10dm_nonneg : 0 ≤ 10 * (d : Int) * (m : Int) := by + have h10d_nonneg : 0 ≤ 10 * (d : Int) := Int.mul_nonneg h10_nonneg hd_nonneg + exact Int.mul_nonneg h10d_nonneg hm_nonneg + have h2d_nonneg : 0 ≤ 2 * (d : Int) := Int.mul_nonneg h2_nonneg hd_nonneg + have h6m_minus7_nonneg : 0 ≤ 6 * (m : Int) - 7 := by + have hm_ge_two : (2 : Int) ≤ (m : Int) := by exact_mod_cast hm2 + omega + + have h0 : + 0 ≤ (12 * (d : Int) * (m : Int) + 3 * (m : Int) * (m : Int) - 3 * (m : Int) - 1) + + ((m - 1 : Nat) : Int) * (-2 * (d : Int) - 3 * (m : Int) + 6) := by + rw [h_rewrite0] + omega + + have hLin : + 0 ≤ (-2 * (d : Int) * (r : Int) + 12 * (d : Int) * (m : Int) + + 3 * (m : Int) * (m : Int) - 3 * (m : Int) * (r : Int) + - 3 * (m : Int) + 6 * (r : Int) - 1) := by + rw [h_rewrite] + have h_add : + (12 * (d : Int) * (m : Int) + 3 * (m : Int) * (m : Int) - 3 * (m : Int) - 1) + + ((m - 1 : Nat) : Int) * (-2 * (d : Int) - 3 * (m : Int) + 6) + ≤ + (12 * (d : Int) * (m : Int) + 3 * (m : Int) * (m : Int) - 3 * (m : Int) - 1) + + (r : Int) * (-2 * (d : Int) - 3 * (m : Int) + 6) := by + exact Int.add_le_add_left h_mul_lower _ + exact Int.le_trans h0 h_add + + have hdiff_nonneg : + 0 ≤ (((m : Int) - 2 * (d : Int) + 3 * (q : Int) + 6) * + (((m : Int) + (d : Int)) * ((m : Int) + (d : Int))) + - ((m : Int) + 1) * ((m : Int) + 1) * ((m : Int) + 1)) := by + rw [hEqInt] + exact Int.add_nonneg hQ_nonneg hLin + + have hIntMain : + ((m : Int) + 1) * ((m : Int) + 1) * ((m : Int) + 1) ≤ + ((m : Int) - 2 * (d : Int) + 3 * (q : Int) + 6) * + (((m : Int) + (d : Int)) * ((m : Int) + (d : Int))) := by + omega + + have hCoeffLe : + ((m : Int) - 2 * (d : Int) + 3 * (q : Int) + 6) + ≤ ((m - 2 * d + 3 * q + 6 : Nat) : Int) := by + omega + + have hz_nonneg : 0 ≤ (((m : Int) + (d : Int)) * ((m : Int) + (d : Int))) := by + have : 0 ≤ (m : Int) + (d : Int) := Int.add_nonneg hm_nonneg hd_nonneg + exact Int.mul_nonneg this this + + have hIntNatCoeff : + ((m : Int) + 1) * ((m : Int) + 1) * ((m : Int) + 1) + ≤ ((m - 2 * d + 3 * q + 6 : Nat) : Int) * + (((m : Int) + (d : Int)) * ((m : Int) + (d : Int))) := by + exact Int.le_trans hIntMain (Int.mul_le_mul_of_nonneg_right hCoeffLe hz_nonneg) + + exact_mod_cast hIntNatCoeff + +/-- Division form of the one-step upper bound. -/ +private theorem one_step_div_bound (m d : Nat) (hm2 : 2 ≤ m) : + (((m + 1) * (m + 1) * (m + 1) - 1) / ((m + d) * (m + d))) + ≤ m - 2 * d + 3 * (d * d / m) + 5 := by + let A : Nat := m - 2 * d + 3 * (d * d / m) + 5 + let B : Nat := (m + d) * (m + d) + have hBpos : 0 < B := by + dsimp [B] + exact Nat.mul_pos (by omega) (by omega) + have hprod : (m + 1) * (m + 1) * (m + 1) ≤ (A + 1) * B := by + dsimp [A, B] + simpa [Nat.add_assoc] using one_step_prod_bound m d hm2 + have hpred : (m + 1) * (m + 1) * (m + 1) - 1 < (m + 1) * (m + 1) * (m + 1) := by + have hpos : 0 < (m + 1) * (m + 1) * (m + 1) := by + have hm1 : 0 < m + 1 := by omega + exact Nat.mul_pos (Nat.mul_pos hm1 hm1) hm1 + exact Nat.sub_lt hpos (by omega) + have hlt : (m + 1) * (m + 1) * (m + 1) - 1 < (A + 1) * B := + Nat.lt_of_lt_of_le hpred hprod + have hdivlt : (((m + 1) * (m + 1) * (m + 1) - 1) / B) < A + 1 := by + exact (Nat.div_lt_iff_lt_mul hBpos).2 hlt + have hdivle : (((m + 1) * (m + 1) * (m + 1) - 1) / B) ≤ A := by + exact Nat.lt_succ_iff.mp hdivlt + simpa [A, B] + +/-- If `x < (m+1)^3` and `z = m+d` with `2d ≤ m`, one cbrt step keeps + the overestimate within `d^2/m + 1`. -/ +private theorem cbrtStep_upper_of_delta + (x m d : Nat) + (hm2 : 2 ≤ m) + (h2d : 2 * d ≤ m) + (hx : x < (m + 1) * (m + 1) * (m + 1)) : + cbrtStep x (m + d) ≤ m + (d * d / m) + 1 := by + let q : Nat := d * d / m + let z : Nat := m + d + have hxle : x ≤ (m + 1) * (m + 1) * (m + 1) - 1 := by omega + have hdiv_x : x / (z * z) ≤ ((m + 1) * (m + 1) * (m + 1) - 1) / (z * z) := + Nat.div_le_div_right hxle + have hdiv_m : + ((m + 1) * (m + 1) * (m + 1) - 1) / (z * z) ≤ m - 2 * d + 3 * q + 5 := by + simpa [z, q, Nat.mul_assoc] using one_step_div_bound m d hm2 + have hdiv : x / (z * z) ≤ m - 2 * d + 3 * q + 5 := Nat.le_trans hdiv_x hdiv_m + unfold cbrtStep + have hsum : x / (z * z) + 2 * z ≤ (m - 2 * d + 3 * q + 5) + 2 * z := by + exact Nat.add_le_add_right hdiv _ + have hdiv3 : + (x / (z * z) + 2 * z) / 3 + ≤ ((m - 2 * d + 3 * q + 5) + 2 * z) / 3 := + Nat.div_le_div_right hsum + have hfinal : ((m - 2 * d + 3 * q + 5) + 2 * z) / 3 ≤ m + q + 1 := by + omega + have hz : z = m + d := by rfl + rw [hz] at hdiv3 hfinal + simpa [q] using Nat.le_trans hdiv3 hfinal + +/-- Upper-bound transfer form: if `z` is between `m` and `m+d`, one cbrt step is + bounded by the same `d^2/m + 1` expression. -/ +theorem cbrtStep_upper_of_le + (x m z d : Nat) + (hm2 : 2 ≤ m) + (hmz : m ≤ z) + (hzd : z ≤ m + d) + (h2d : 2 * d ≤ m) + (hx : x < (m + 1) * (m + 1) * (m + 1)) : + cbrtStep x z ≤ m + (d * d / m) + 1 := by + let d' : Nat := z - m + have hz_eq : z = m + d' := by + dsimp [d'] + omega + have hd'_le : d' ≤ d := by + dsimp [d'] + omega + have h2d' : 2 * d' ≤ m := Nat.le_trans (Nat.mul_le_mul_left 2 hd'_le) h2d + have hstep' : cbrtStep x z ≤ m + (d' * d' / m) + 1 := by + rw [hz_eq] + exact cbrtStep_upper_of_delta x m d' hm2 h2d' hx + have hsq : d' * d' ≤ d * d := Nat.mul_le_mul hd'_le hd'_le + have hdiv : d' * d' / m ≤ d * d / m := Nat.div_le_div_right hsq + have hmono : m + (d' * d' / m) + 1 ≤ m + (d * d / m) + 1 := by + exact Nat.add_le_add_left (Nat.add_le_add_right hdiv 1) m + exact Nat.le_trans hstep' hmono + +-- ============================================================================ +-- Part 4: innerCbrt structure +-- ============================================================================ + +/-- `_cbrt` is exactly `run6From` from the seed (definitional). -/ +theorem innerCbrt_eq_run6From_seed (x : Nat) : + innerCbrt x = run6From x (cbrtSeed x) := rfl + +/-- `_cbrt` is `cbrtStep` applied to `run5From` of the seed (definitional). -/ +theorem innerCbrt_eq_step_run5_seed (x : Nat) : + innerCbrt x = cbrtStep x (run5From x (cbrtSeed x)) := rfl + +set_option maxRecDepth 1000000 in +/-- Direct finite check for small inputs. -/ +private theorem innerCbrt_upper_fin256 : + ∀ i : Fin 256, innerCbrt i.val ≤ icbrt i.val + 1 := by + decide + +/-- Small-range corollary (used for base cases). -/ +theorem innerCbrt_upper_of_lt_256 (x : Nat) (hx : x < 256) : + innerCbrt x ≤ icbrt x + 1 := by + simpa using innerCbrt_upper_fin256 ⟨x, hx⟩ + +/-- innerCbrt gives a lower bound: for any m with m³ ≤ x, m ≤ innerCbrt(x). -/ +theorem innerCbrt_lower (x m : Nat) (hx : 0 < x) + (hm : m * m * m ≤ x) : m ≤ innerCbrt x := by + unfold innerCbrt + have hs := cbrtSeed_pos x + have h1 := cbrtStep_pos x _ hx hs + have h2 := cbrtStep_pos x _ hx h1 + have h3 := cbrtStep_pos x _ hx h2 + have h4 := cbrtStep_pos x _ hx h3 + have h5 := cbrtStep_pos x _ hx h4 + exact cbrt_step_floor_bound x _ m h5 hm + +-- ============================================================================ +-- Part 5: Main correctness theorems (under explicit upper-bound hypothesis) +-- ============================================================================ + +/-- Positivity of `innerCbrt` for positive `x`. -/ +theorem innerCbrt_pos (x : Nat) (hx : 0 < x) : 0 < innerCbrt x := by + have h1 : (1 : Nat) * 1 * 1 ≤ x := by omega + have h := innerCbrt_lower x 1 hx h1 + omega + +/-- If `innerCbrt x` is at most `icbrt x + 1`, then it is exactly one of those two values. -/ +theorem innerCbrt_correct_of_upper (x : Nat) (hx : 0 < x) + (hupper : innerCbrt x ≤ icbrt x + 1) : + innerCbrt x = icbrt x ∨ innerCbrt x = icbrt x + 1 := by + have hlow : icbrt x ≤ innerCbrt x := innerCbrt_lower x (icbrt x) hx (icbrt_cube_le x) + by_cases heq : innerCbrt x = icbrt x + · exact Or.inl heq + · have hneq : icbrt x ≠ innerCbrt x := by + intro h' + exact heq h'.symm + have hlt : icbrt x < innerCbrt x := Nat.lt_of_le_of_ne hlow hneq + have hge : icbrt x + 1 ≤ innerCbrt x := Nat.succ_le_of_lt hlt + exact Or.inr (Nat.le_antisymm hupper hge) + +/-- Useful consequence of the upper-bound hypothesis: `(innerCbrt x - 1)^3 ≤ x`. -/ +theorem innerCbrt_pred_cube_le_of_upper (x : Nat) + (hupper : innerCbrt x ≤ icbrt x + 1) : + (innerCbrt x - 1) * (innerCbrt x - 1) * (innerCbrt x - 1) ≤ x := by + have hpred_le : innerCbrt x - 1 ≤ icbrt x := by omega + have hmono : + (innerCbrt x - 1) * (innerCbrt x - 1) * (innerCbrt x - 1) ≤ + icbrt x * icbrt x * icbrt x := cube_monotone hpred_le + exact Nat.le_trans hmono (icbrt_cube_le x) + +/-- For positive `x`, `x` is strictly below `(innerCbrt x + 1)^3`. -/ +theorem innerCbrt_lt_succ_cube (x : Nat) (hx : 0 < x) : + x < (innerCbrt x + 1) * (innerCbrt x + 1) * (innerCbrt x + 1) := by + by_cases hlt : x < (innerCbrt x + 1) * (innerCbrt x + 1) * (innerCbrt x + 1) + · exact hlt + · have hle : (innerCbrt x + 1) * (innerCbrt x + 1) * (innerCbrt x + 1) ≤ x := Nat.le_of_not_lt hlt + have hcontra : innerCbrt x + 1 ≤ innerCbrt x := innerCbrt_lower x (innerCbrt x + 1) hx hle + exact False.elim ((Nat.not_succ_le_self (innerCbrt x)) hcontra) + +-- ============================================================================ +-- Part 6: Perfect-cube exactness (innerCbrt(m³) = m) +-- ============================================================================ + +/-- cbrtStep is a fixed point at the exact cube root: cbrtStep(m³, m) = m. -/ +theorem cbrtStep_fixed_point_on_perfect_cube + (m : Nat) (hm : 0 < m) : + cbrtStep (m * m * m) m = m := by + unfold cbrtStep + have hzz : 0 < m * m := Nat.mul_pos hm hm + have hdiv : m * m * m / (m * m) = m := by + rw [Nat.mul_assoc] + exact Nat.mul_div_cancel m hzz + rw [hdiv] + omega + +/-- On a perfect cube, cbrtStep from m+d with d² < m gives exactly m. + Key: m³ < (m-2d+3)(m+d)², so floor(m³/(m+d)²) ≤ m-2d+2, + giving numerator ≤ 3m+2, so step = m. + The strict inequality follows from 3(m+d)² > d²(3m+2d) when d² < m. -/ +theorem cbrtStep_eq_on_perfect_cube_of_sq_lt + (m d : Nat) (hm : 2 ≤ m) (h2d : 2 * d ≤ m) (hdsq : d * d < m) : + cbrtStep (m * m * m) (m + d) = m := by + by_cases hd0 : d = 0 + · subst hd0; simp only [Nat.add_zero] + exact cbrtStep_fixed_point_on_perfect_cube m (by omega) + · have hd : 0 < d := Nat.pos_of_ne_zero hd0 + -- Lower bound from floor bound + have hlo : m ≤ cbrtStep (m * m * m) (m + d) := + cbrt_step_floor_bound (m * m * m) (m + d) m (by omega) (Nat.le_refl _) + -- Upper bound: suffices numerator ≤ 3m+2 + suffices hup : cbrtStep (m * m * m) (m + d) ≤ m by omega + unfold cbrtStep + let z := m + d + have hz : 0 < z := by omega + have hzz : 0 < z * z := Nat.mul_pos hz hz + -- Goal: (m*m*m / (z*z) + 2*z) / 3 ≤ m, i.e., numerator ≤ 3m+2 + -- Strategy: show m³ < (m-2d+3)*z², so m³/z² ≤ m-2d+2, so num ≤ 3m+2, so step ≤ m. + -- Step 1: d²(3m+2d) < 3z² (the key inequality using d² < m) + have hkey : d * d * (3 * m + 2 * d) < 3 * (z * z) := by + -- d²(3m+2d) < m(3m+2d) ≤ 3(m+d)² + have h3m2d : 0 < 3 * m + 2 * d := by omega + have hstep1 : d * d * (3 * m + 2 * d) < m * (3 * m + 2 * d) := + Nat.mul_lt_mul_of_pos_right hdsq h3m2d + -- m(3m+2d) ≤ 3(m+d)²: expand both sides to mm + md + dd terms + have hstep2 : m * (3 * m + 2 * d) ≤ 3 * (z * z) := by + show m * (3 * m + 2 * d) ≤ 3 * ((m + d) * (m + d)) + -- LHS = 3mm + 2md. RHS = 3mm + 6md + 3dd. Diff = 4md + 3dd ≥ 0. + have hLmm : m * (3 * m) = 3 * (m * m) := by + rw [← Nat.mul_assoc, Nat.mul_comm m 3, Nat.mul_assoc] + have hLmd : m * (2 * d) = 2 * (m * d) := by + rw [← Nat.mul_assoc, Nat.mul_comm m 2, Nat.mul_assoc] + have hR : (m + d) * (m + d) = m * m + 2 * (m * d) + d * d := by + rw [Nat.add_mul, Nat.mul_add, Nat.mul_add, Nat.mul_comm d m]; omega + rw [Nat.mul_add m, hLmm, hLmd, hR]; omega + exact Nat.lt_of_lt_of_le hstep1 hstep2 + -- Step 2: polynomial identity m³ = z²(m-2d) + d²(3m+2d) + -- Substitute a = m - 2d to eliminate Nat subtraction, then expand both sides. + have hident : m * m * m = z * z * (m - 2 * d) + d * d * (3 * m + 2 * d) := by + show m * m * m = (m + d) * (m + d) * (m - 2 * d) + d * d * (3 * m + 2 * d) + generalize ha : m - 2 * d = a + have hm_eq : m = a + 2 * d := by omega + subst hm_eq + -- Both sides expand to a³+6a²d+12ad²+8d³ + have h3 : (a + 2 * d) + d = a + 3 * d := by omega + have h8 : 3 * (a + 2 * d) + 2 * d = 3 * a + 8 * d := by omega + rw [h3, h8] + rw [show 2 * d = d + d from by omega, + show 3 * d = d + (d + d) from by omega, + show 3 * a = a + (a + a) from by omega, + show 8 * d = d + (d + (d + (d + (d + (d + (d + d)))))) from by omega] + simp only [Nat.add_mul, Nat.mul_add] + simp only [Nat.mul_assoc] + simp only [Nat.mul_comm d a, Nat.mul_left_comm d a] + omega + -- Step 3: combine identity + key inequality to get m³ < (m-2d+3)*z² + have hlt : m * m * m < (m - 2 * d + 3) * (z * z) := by + calc m * m * m + = z * z * (m - 2 * d) + d * d * (3 * m + 2 * d) := hident + _ < z * z * (m - 2 * d) + 3 * (z * z) := Nat.add_lt_add_left hkey _ + _ = z * z * (m - 2 * d) + z * z * 3 := by rw [Nat.mul_comm 3 _] + _ = z * z * (m - 2 * d + 3) := by rw [← Nat.mul_add] + _ = (m - 2 * d + 3) * (z * z) := Nat.mul_comm _ _ + -- Step 4: from m³ < (m-2d+3)*z², derive m³/z² < m-2d+3, so m³/z² ≤ m-2d+2 + have hdiv_lt : m * m * m / (z * z) < m - 2 * d + 3 := + (Nat.div_lt_iff_lt_mul hzz).2 hlt + have hdiv_le : m * m * m / (z * z) ≤ m - 2 * d + 2 := by omega + -- numerator = floor(m³/z²) + 2z ≤ (m-2d+2) + 2(m+d) = 3m+2 + have hnum_le : m * m * m / (z * z) + 2 * z ≤ 3 * m + 2 := by omega + -- (3m+2)/3 ≤ m: use Nat.div_le_div_right on the numerator bound + exact Nat.le_trans (Nat.div_le_div_right hnum_le) (by omega) + +set_option maxRecDepth 1000000 in +/-- Finite check: innerCbrt(m³) = m for all m ≤ 255 (m³ < 2^24). -/ +theorem innerCbrt_on_perfect_cube_small : + ∀ i : Fin 256, innerCbrt (i.val * i.val * i.val) = i.val := by + decide + +-- ============================================================================ +-- Part 7: Floor correction +-- ============================================================================ + +/-- The cbrt floor correction is correct. + Given z > 0, (z-1)³ ≤ x < (z+1)³, the correction gives icbrt(x). + Correction: if x/(z*z) < z then z-1 else z. + When x/(z*z) < z: z³ > x, so z is a ceiling → return z-1. + When x/(z*z) ≥ z: z³ ≤ x, so z is the floor → return z. -/ +theorem cbrt_floor_correction (x z : Nat) (hz : 0 < z) + (hlo : (z - 1) * (z - 1) * (z - 1) ≤ x) + (hhi : x < (z + 1) * (z + 1) * (z + 1)) : + let r := if x / (z * z) < z then z - 1 else z + r * r * r ≤ x ∧ x < (r + 1) * (r + 1) * (r + 1) := by + simp only + have hzz : 0 < z * z := Nat.mul_pos hz hz + by_cases h_lt : x / (z * z) < z + · -- x/(z²) < z means z³ > x + simp [h_lt] + have h_zcube : x < z * z * z := by + have h_euc := Nat.div_add_mod x (z * z) + have h_mod := Nat.mod_lt x hzz + have h1 : x < (z * z) * (x / (z * z) + 1) := by rw [Nat.mul_add, Nat.mul_one]; omega + have h2 : x / (z * z) + 1 ≤ z := by omega + calc x < z * z * (x / (z * z) + 1) := h1 + _ ≤ z * z * z := Nat.mul_le_mul_left (z * z) h2 + constructor + · exact hlo + · have : z - 1 + 1 = z := by omega + rw [this]; exact h_zcube + · -- x/(z²) ≥ z means z³ ≤ x + simp [h_lt] + simp only [Nat.not_lt] at h_lt + have h_zcube : z * z * z ≤ x := by + have h_div_le : z * z * (x / (z * z)) ≤ x := Nat.mul_div_le x (z * z) + calc z * z * z + ≤ z * z * (x / (z * z)) := Nat.mul_le_mul_left (z * z) h_lt + _ ≤ x := h_div_le + exact ⟨h_zcube, hhi⟩ + +/-- If `innerCbrt` is bracketed by ±1 around the true floor root, floor correction returns `icbrt`. -/ +private theorem floorCbrt_eq_icbrt_of_bounds (x : Nat) + (hz : 0 < innerCbrt x) + (hlo : (innerCbrt x - 1) * (innerCbrt x - 1) * (innerCbrt x - 1) ≤ x) + (hhi : x < (innerCbrt x + 1) * (innerCbrt x + 1) * (innerCbrt x + 1)) : + floorCbrt x = icbrt x := by + let r := if x / (innerCbrt x * innerCbrt x) < innerCbrt x then innerCbrt x - 1 else innerCbrt x + have hcorr : r * r * r ≤ x ∧ x < (r + 1) * (r + 1) * (r + 1) := by + simpa [r] using cbrt_floor_correction x (innerCbrt x) hz hlo hhi + have hr : floorCbrt x = r := by + unfold floorCbrt + rfl + exact hr.trans (icbrt_eq_of_bounds x r hcorr.1 hcorr.2) + +/-- End-to-end floor correctness, with the remaining upper-bound link explicit. -/ +theorem floorCbrt_correct_of_upper (x : Nat) (hx : 0 < x) + (hupper : innerCbrt x ≤ icbrt x + 1) : + floorCbrt x = icbrt x := by + have hz := innerCbrt_pos x hx + have hlo := innerCbrt_pred_cube_le_of_upper x hupper + have hhi := innerCbrt_lt_succ_cube x hx + exact floorCbrt_eq_icbrt_of_bounds x hz hlo hhi diff --git a/formal/cbrt/CbrtProof/CbrtProof/CertifiedChain.lean b/formal/cbrt/CbrtProof/CbrtProof/CertifiedChain.lean new file mode 100644 index 000000000..63ebcc7ca --- /dev/null +++ b/formal/cbrt/CbrtProof/CbrtProof/CertifiedChain.lean @@ -0,0 +1,372 @@ +/- + Certified chain: 6 cbrt NR steps with per-octave error tracking. + + Given a certificate octave i with: + - lo ≤ m ≤ hi (bounds on icbrt(x)) + - seed = cbrt seed for the octave + - d1..d6 error bounds with d6 ≤ 1 + + We prove: run6From x (seedOf i) ≤ m + 1. + + The proof chains: + Step 1: d1 bound from analytic formula (cbrt_d1_bound) + Steps 2-6: each step contracts via cbrtStep_upper_of_le + relaxation to lo +-/ +import Init +import CbrtProof.FloorBound +import CbrtProof.CbrtCorrect +import CbrtProof.FiniteCert + +namespace CbrtCertified + +open CbrtCert + +-- ============================================================================ +-- Pure polynomial identities (no subtraction) +-- ============================================================================ + +/-- d²(d+3s) + 3(d+s)s² = (d+s)³ + 2s³ -/ +private theorem poly_id_ge (d s : Nat) : + d * d * (d + s + 2 * s) + 3 * (d + s) * (s * s) + = (d + s) * (d + s) * (d + s) + 2 * (s * s * s) := by + simp only [Nat.add_mul, Nat.mul_add, Nat.mul_assoc, + Nat.mul_comm, Nat.mul_left_comm, Nat.add_assoc, Nat.add_left_comm] + have h1 : d * (d * (s * 2)) = 2 * (d * (d * s)) := by + rw [show d * (s * 2) = (d * s) * 2 from by rw [← Nat.mul_assoc]] + rw [show d * ((d * s) * 2) = (d * (d * s)) * 2 from by rw [← Nat.mul_assoc]] + rw [Nat.mul_comm] + have h2 : d * (s * (s * 3)) = 3 * (d * (s * s)) := by + rw [show s * (s * 3) = (s * s) * 3 from by rw [← Nat.mul_assoc]] + rw [show d * ((s * s) * 3) = (d * (s * s)) * 3 from by rw [← Nat.mul_assoc]] + rw [Nat.mul_comm] + have h3 : s * (s * (s * 2)) = 2 * (s * (s * s)) := by + rw [show s * (s * 2) = (s * s) * 2 from by rw [← Nat.mul_assoc]] + rw [show s * ((s * s) * 2) = (s * (s * s)) * 2 from by rw [← Nat.mul_assoc]] + rw [Nat.mul_comm] + have h4 : s * (s * (s * 3)) = 3 * (s * (s * s)) := by + rw [show s * (s * 3) = (s * s) * 3 from by rw [← Nat.mul_assoc]] + rw [show s * ((s * s) * 3) = (s * (s * s)) * 3 from by rw [← Nat.mul_assoc]] + rw [Nat.mul_comm] + omega + +/-- d²(m+2(d+m)) + 3m(d+m)² = m³ + 2(d+m)³ -/ +private theorem poly_id_le (d m : Nat) : + d * d * (m + 2 * (d + m)) + 3 * m * ((d + m) * (d + m)) + = m * m * m + 2 * ((d + m) * (d + m) * (d + m)) := by + simp only [Nat.add_mul, Nat.mul_add, Nat.mul_assoc, + Nat.mul_comm, Nat.mul_left_comm, Nat.add_assoc, Nat.add_left_comm] + have h1 : d * (d * (m * 2)) = 2 * (d * (d * m)) := by + rw [show d * (m * 2) = (d * m) * 2 from by rw [← Nat.mul_assoc]] + rw [show d * ((d * m) * 2) = (d * (d * m)) * 2 from by rw [← Nat.mul_assoc]] + rw [Nat.mul_comm] + have h2 : d * (m * (m * 2)) = 2 * (d * (m * m)) := by + rw [show m * (m * 2) = (m * m) * 2 from by rw [← Nat.mul_assoc]] + rw [show d * ((m * m) * 2) = (d * (m * m)) * 2 from by rw [← Nat.mul_assoc]] + rw [Nat.mul_comm] + have h3 : d * (d * (d * 2)) = 2 * (d * (d * d)) := by + rw [show d * (d * 2) = (d * d) * 2 from by rw [← Nat.mul_assoc]] + rw [show d * ((d * d) * 2) = (d * (d * d)) * 2 from by rw [← Nat.mul_assoc]] + rw [Nat.mul_comm] + have h4 : m * (m * (m * 2)) = 2 * (m * (m * m)) := by + rw [show m * (m * 2) = (m * m) * 2 from by rw [← Nat.mul_assoc]] + rw [show m * ((m * m) * 2) = (m * (m * m)) * 2 from by rw [← Nat.mul_assoc]] + rw [Nat.mul_comm] + have h5 : m * (m * (m * 3)) = 3 * (m * (m * m)) := by + rw [show m * (m * 3) = (m * m) * 3 from by rw [← Nat.mul_assoc]] + rw [show m * ((m * m) * 3) = (m * (m * m)) * 3 from by rw [← Nat.mul_assoc]] + rw [Nat.mul_comm] + have h6 : d * (m * (m * 3)) = 3 * (d * (m * m)) := by + rw [show m * (m * 3) = (m * m) * 3 from by rw [← Nat.mul_assoc]] + rw [show d * ((m * m) * 3) = (d * (m * m)) * 3 from by rw [← Nat.mul_assoc]] + rw [Nat.mul_comm] + have h7 : d * (d * (m * 3)) = 3 * (d * (d * m)) := by + rw [show d * (m * 3) = (d * m) * 3 from by rw [← Nat.mul_assoc]] + rw [show d * ((d * m) * 3) = (d * (d * m)) * 3 from by rw [← Nat.mul_assoc]] + rw [Nat.mul_comm] + omega + +-- ============================================================================ +-- Step-from-bound: one NR step error bound using lo as denominator +-- ============================================================================ + +/-- One NR step with certificate denominator. + If z ∈ [m, m+D] and 2D ≤ m and lo ≤ m, then cbrtStep(x, z) - m ≤ D²/lo + 1. + Relaxes cbrtStep_upper_of_le from D²/m to D²/lo. -/ +theorem step_from_bound + (x m lo z D : Nat) + (hm2 : 2 ≤ m) + (hloPos : 0 < lo) + (hlo : lo ≤ m) + (hxhi : x < (m + 1) * (m + 1) * (m + 1)) + (hmz : m ≤ z) + (hzD : z - m ≤ D) + (h2D : 2 * D ≤ m) : + cbrtStep x z - m ≤ nextD lo D := by + have hzD' : z ≤ m + D := by omega + have hstep : cbrtStep x z ≤ m + (D * D / m) + 1 := + cbrtStep_upper_of_le x m z D hm2 hmz hzD' h2D hxhi + have hDm : D * D / m ≤ D * D / lo := + Nat.div_le_div_left hlo hloPos + have hle : cbrtStep x z ≤ m + (D * D / lo) + 1 := + Nat.le_trans hstep (Nat.add_le_add_right (Nat.add_le_add_left hDm m) 1) + -- Goal: cbrtStep x z - m ≤ D * D / lo + 1 + -- From hle: cbrtStep x z ≤ m + (D * D / lo) + 1 = (D * D / lo + 1) + m + -- By Nat.sub_le_of_le_add: cbrtStep x z - m ≤ D * D / lo + 1 + unfold nextD + exact Nat.sub_le_of_le_add (by omega : cbrtStep x z ≤ (D * D / lo + 1) + m) + +-- ============================================================================ +-- First-step (d1) bound: analytic formula via cubic identity +-- ============================================================================ + +/-- Witness identity for m ≥ s: + (m-s)²(m+2s) + 3ms² = m³+2s³. + Used to prove AM-GM and the d1 bound. -/ +private theorem cubic_witness_ge (m s : Nat) (h : s ≤ m) : + (m - s) * (m - s) * (m + 2 * s) + 3 * m * (s * s) + = m * m * m + 2 * (s * s * s) := by + generalize hd : m - s = d + have hm : m = d + s := by omega + rw [hm] + exact poly_id_ge d s + +/-- Witness identity for s > m: + (s-m)²(m+2s) + 3ms² = m³+2s³. -/ +private theorem cubic_witness_le (m s : Nat) (h : m ≤ s) : + (s - m) * (s - m) * (m + 2 * s) + 3 * m * (s * s) + = m * m * m + 2 * (s * s * s) := by + generalize hd : s - m = d + have hs : s = d + m := by omega + rw [hs] + exact poly_id_le d m + +/-- First-step error bound for cbrt NR step. + Uses: 3s²(z₁ - m) ≤ (m-s)²(m+2s) + 3m(m+1) ≤ maxAbs²(hi+2s) + 3hi(hi+1). -/ +theorem cbrt_d1_bound + (x m s lo hi : Nat) + (hs : 0 < s) + (hmlo : m * m * m ≤ x) + (hmhi : x < (m + 1) * (m + 1) * (m + 1)) + (hlo : lo ≤ m) + (hhi : m ≤ hi) : + let maxAbs := max (s - lo) (hi - s) + cbrtStep x s - m ≤ (maxAbs * maxAbs * (hi + 2 * s) + 3 * hi * (hi + 1)) / + (3 * (s * s)) := by + simp only + unfold cbrtStep + -- Floor bound: m ≤ z₁ + have hmstep : m ≤ (x / (s * s) + 2 * s) / 3 := + cbrt_step_floor_bound x s m hs hmlo + have hss : 0 < s * s := Nat.mul_pos hs hs + have h3ss : 0 < 3 * (s * s) := by omega + -- Key bound: 3s²·z₁ ≤ x + 2s³ + -- From: 3·z₁ ≤ ⌊x/s²⌋ + 2s and s²·⌊x/s²⌋ ≤ x. + have h3z1 : 3 * ((x / (s * s) + 2 * s) / 3) ≤ x / (s * s) + 2 * s := + Nat.mul_div_le _ 3 + have hfloor : s * s * (x / (s * s)) ≤ x := Nat.mul_div_le x (s * s) + have h3ssz1 : 3 * (s * s) * ((x / (s * s) + 2 * s) / 3) ≤ x + 2 * (s * s * s) := by + have hmul : s * s * (3 * ((x / (s * s) + 2 * s) / 3)) ≤ + s * s * (x / (s * s) + 2 * s) := + Nat.mul_le_mul_left _ h3z1 + have hexp : s * s * (x / (s * s) + 2 * s) = + s * s * (x / (s * s)) + s * s * (2 * s) := Nat.mul_add _ _ _ + have hexp2 : s * s * (2 * s) = 2 * (s * s * s) := by + rw [Nat.mul_comm 2 s, ← Nat.mul_assoc (s * s) s 2, Nat.mul_comm (s * s * s) 2] + have hcomm : s * s * (3 * ((x / (s * s) + 2 * s) / 3)) = + 3 * (s * s) * ((x / (s * s) + 2 * s) / 3) := by + rw [← Nat.mul_assoc (s * s) 3, Nat.mul_comm (s * s) 3] + rw [← hcomm] + calc s * s * (3 * ((x / (s * s) + 2 * s) / 3)) + ≤ s * s * (x / (s * s) + 2 * s) := hmul + _ = s * s * (x / (s * s)) + s * s * (2 * s) := hexp + _ = s * s * (x / (s * s)) + 2 * (s * s * s) := by rw [hexp2] + _ ≤ x + 2 * (s * s * s) := Nat.add_le_add_right hfloor _ + -- 3s²m ≤ 3s²z₁ + have h3ssm : 3 * (s * s) * m ≤ 3 * (s * s) * ((x / (s * s) + 2 * s) / 3) := + Nat.mul_le_mul_left _ hmstep + -- AM-GM: 3ms² ≤ m³+2s³ ≤ x+2s³ + have ham : 3 * m * (s * s) ≤ x + 2 * (s * s * s) := by + by_cases hsm : s ≤ m + · have := cubic_witness_ge m s hsm; omega + · have := cubic_witness_le m s (by omega); omega + -- 3s²(z₁-m) ≤ (x+2s³) - 3ms² + have hsub : 3 * (s * s) * ((x / (s * s) + 2 * s) / 3 - m) ≤ + x + 2 * (s * s * s) - 3 * m * (s * s) := by + rw [Nat.mul_sub (3 * (s * s)) ((x / (s * s) + 2 * s) / 3) m] + have hcomm2 : 3 * (s * s) * m = 3 * m * (s * s) := by + rw [Nat.mul_assoc 3 (s * s) m, Nat.mul_comm (s * s) m, ← Nat.mul_assoc 3 m (s * s)] + rw [hcomm2] + exact Nat.sub_le_sub_right h3ssz1 _ + -- x+2s³-3ms² ≤ (m³+3m²+3m)+2s³-3ms² + have hxup : x ≤ m * m * m + 3 * (m * m) + 3 * m := by + have : (m + 1) * (m + 1) * (m + 1) = m * m * m + 3 * (m * m) + 3 * m + 1 := by + simp only [Nat.add_mul, Nat.mul_add, Nat.mul_one, Nat.one_mul, Nat.mul_assoc, + Nat.add_assoc] + omega + omega + -- (m³+3m²+3m)+2s³-3ms² = diff²·(m+2s) + 3m(m+1) + -- ≤ maxAbs²·(hi+2s) + 3hi(hi+1) + -- Use Nat.le_div_iff_mul_le to convert goal + rw [Nat.le_div_iff_mul_le h3ss] + -- Goal: (z₁-m)·(3s²) ≤ maxAbs²·(hi+2s) + 3hi(hi+1) + -- Chain through the bounds + let RHS := max (s - lo) (hi - s) * max (s - lo) (hi - s) * (hi + 2 * s) + 3 * hi * (hi + 1) + suffices h : 3 * (s * s) * ((x / (s * s) + 2 * s) / 3 - m) ≤ RHS by + calc ((x / (s * s) + 2 * s) / 3 - m) * (3 * (s * s)) + = 3 * (s * s) * ((x / (s * s) + 2 * s) / 3 - m) := by + rw [Nat.mul_comm] + _ ≤ RHS := h + -- Chain: 3s²(z₁-m) ≤ x+2s³-3ms² ≤ m³+3m²+3m+2s³-3ms² ≤ RHS + have hstep1 : x + 2 * (s * s * s) - 3 * m * (s * s) ≤ + (m * m * m + 3 * (m * m) + 3 * m) + 2 * (s * s * s) - 3 * m * (s * s) := + Nat.sub_le_sub_right (Nat.add_le_add_right hxup _) _ + -- Now bound the cubic difference and the quadratic term + have hcubic : m * m * m + 2 * (s * s * s) - 3 * m * (s * s) ≤ + max (s - lo) (hi - s) * max (s - lo) (hi - s) * (hi + 2 * s) := by + by_cases hsm : s ≤ m + · have hid := cubic_witness_ge m s hsm + have hident : m * m * m + 2 * (s * s * s) - 3 * m * (s * s) = + (m - s) * (m - s) * (m + 2 * s) := by omega + rw [hident] + have hdiff : m - s ≤ hi - s := Nat.sub_le_sub_right hhi s + have hdm : m - s ≤ max (s - lo) (hi - s) := + Nat.le_trans hdiff (Nat.le_max_right _ _) + have hm2s : m + 2 * s ≤ hi + 2 * s := Nat.add_le_add_right hhi _ + exact Nat.mul_le_mul (Nat.mul_le_mul hdm hdm) hm2s + · -- s > m case (¬(s ≤ m) means m < s) + have hsm' : m ≤ s := by omega + have hid := cubic_witness_le m s hsm' + have hident : m * m * m + 2 * (s * s * s) - 3 * m * (s * s) = + (s - m) * (s - m) * (m + 2 * s) := by omega + rw [hident] + have hdiff : s - m ≤ s - lo := Nat.sub_le_sub_left hlo s + have hdm : s - m ≤ max (s - lo) (hi - s) := + Nat.le_trans hdiff (Nat.le_max_left _ _) + have hm2s : m + 2 * s ≤ hi + 2 * s := Nat.add_le_add_right hhi _ + exact Nat.mul_le_mul (Nat.mul_le_mul hdm hdm) hm2s + have hquad : 3 * (m * m) + 3 * m ≤ 3 * hi * (hi + 1) := by + have hmm1 : m * (m + 1) ≤ hi * (hi + 1) := + Nat.mul_le_mul hhi (by omega : m + 1 ≤ hi + 1) + have h3mm : 3 * (m * m) + 3 * m = 3 * (m * (m + 1)) := by + rw [Nat.mul_add m m 1, Nat.mul_one, Nat.mul_add 3 (m * m) m] + have h3hh : 3 * hi * (hi + 1) = 3 * (hi * (hi + 1)) := by + rw [Nat.mul_assoc] + omega + -- AM-GM: 3ms² ≤ m³+2s³ (needed for Nat subtraction safety) + have ham_pure : 3 * m * (s * s) ≤ m * m * m + 2 * (s * s * s) := by + by_cases hsm : s ≤ m + · have := cubic_witness_ge m s hsm; omega + · have := cubic_witness_le m s (by omega); omega + -- Assemble: (m³+3m²+3m)+2s³-3ms² = (m³+2s³-3ms²)+(3m²+3m) [safe since AM-GM] + have hassemble : (m * m * m + 3 * (m * m) + 3 * m) + 2 * (s * s * s) - 3 * m * (s * s) ≤ + RHS := by + -- Rewrite LHS using AM-GM safety and Nat.sub_add_comm + have hadd : (m * m * m + 3 * (m * m) + 3 * m) + 2 * (s * s * s) = + (m * m * m + 2 * (s * s * s)) + (3 * (m * m) + 3 * m) := by omega + rw [hadd, Nat.sub_add_comm ham_pure] + exact Nat.add_le_add hcubic hquad + calc 3 * (s * s) * ((x / (s * s) + 2 * s) / 3 - m) + ≤ x + 2 * (s * s * s) - 3 * m * (s * s) := hsub + _ ≤ (m * m * m + 3 * (m * m) + 3 * m) + 2 * (s * s * s) - 3 * m * (s * s) := hstep1 + _ ≤ RHS := hassemble + +-- ============================================================================ +-- Six-step certified chain +-- ============================================================================ + +/-- Five-step certified chain: z₅ ≥ m, error ≤ d5, and 2*d5 ≤ m. -/ +theorem run5_certified_bounds + (i : Fin 248) (x m : Nat) + (hm2 : 2 ≤ m) + (hmlo : m * m * m ≤ x) + (hmhi : x < (m + 1) * (m + 1) * (m + 1)) + (hlo : loOf i ≤ m) + (hhi : m ≤ hiOf i) : + m ≤ run5From x (seedOf i) ∧ + run5From x (seedOf i) - m ≤ d5Of i ∧ + 2 * d5Of i ≤ m := by + -- Name intermediate values using let + let z1 := cbrtStep x (seedOf i) + let z2 := cbrtStep x z1 + let z3 := cbrtStep x z2 + let z4 := cbrtStep x z3 + let z5 := cbrtStep x z4 + + have hloPos : 0 < loOf i := lo_pos i + have hsPos : 0 < seedOf i := seed_pos i + + -- Lower bounds via floor bound + have hmz1 : m ≤ z1 := cbrt_step_floor_bound x (seedOf i) m hsPos hmlo + have hz1Pos : 0 < z1 := by omega + have hmz2 : m ≤ z2 := cbrt_step_floor_bound x z1 m hz1Pos hmlo + have hz2Pos : 0 < z2 := by omega + have hmz3 : m ≤ z3 := cbrt_step_floor_bound x z2 m hz2Pos hmlo + have hz3Pos : 0 < z3 := by omega + have hmz4 : m ≤ z4 := cbrt_step_floor_bound x z3 m hz3Pos hmlo + have hz4Pos : 0 < z4 := by omega + have hmz5 : m ≤ z5 := cbrt_step_floor_bound x z4 m hz4Pos hmlo + + -- Step 1: d1 bound from analytic formula + have hd1 : z1 - m ≤ d1Of i := by + have h := cbrt_d1_bound x m (seedOf i) (loOf i) (hiOf i) hsPos hmlo hmhi hlo hhi + simp only at h + show cbrtStep x (seedOf i) - m ≤ d1Of i + have hd1eq := d1_eq i + have hmaxeq := maxabs_eq i + rw [hmaxeq] at hd1eq + rw [← hd1eq] at h + exact h + have h2d1 : 2 * d1Of i ≤ m := Nat.le_trans (two_d1_le_lo i) hlo + + -- Steps 2-5 via step_from_bound + have hd2 : z2 - m ≤ d2Of i := by + have h := step_from_bound x m (loOf i) z1 (d1Of i) hm2 hloPos hlo hmhi hmz1 hd1 h2d1 + show cbrtStep x z1 - m ≤ d2Of i + unfold d2Of; exact h + have h2d2 : 2 * d2Of i ≤ m := Nat.le_trans (two_d2_le_lo i) hlo + + have hd3 : z3 - m ≤ d3Of i := by + have h := step_from_bound x m (loOf i) z2 (d2Of i) hm2 hloPos hlo hmhi hmz2 hd2 h2d2 + show cbrtStep x z2 - m ≤ d3Of i + unfold d3Of; exact h + have h2d3 : 2 * d3Of i ≤ m := Nat.le_trans (two_d3_le_lo i) hlo + + have hd4 : z4 - m ≤ d4Of i := by + have h := step_from_bound x m (loOf i) z3 (d3Of i) hm2 hloPos hlo hmhi hmz3 hd3 h2d3 + show cbrtStep x z3 - m ≤ d4Of i + unfold d4Of; exact h + have h2d4 : 2 * d4Of i ≤ m := Nat.le_trans (two_d4_le_lo i) hlo + + have hd5 : z5 - m ≤ d5Of i := by + have h := step_from_bound x m (loOf i) z4 (d4Of i) hm2 hloPos hlo hmhi hmz4 hd4 h2d4 + show cbrtStep x z4 - m ≤ d5Of i + unfold d5Of; exact h + have h2d5 : 2 * d5Of i ≤ m := Nat.le_trans (two_d5_le_lo i) hlo + + exact ⟨hmz5, hd5, h2d5⟩ + +/-- Chain 6 steps through the error recurrence, concluding z₆ ≤ m + 1. -/ +theorem run6_le_m_plus_one + (i : Fin 248) + (x m : Nat) + (hm2 : 2 ≤ m) + (hmlo : m * m * m ≤ x) + (hmhi : x < (m + 1) * (m + 1) * (m + 1)) + (hlo : loOf i ≤ m) + (hhi : m ≤ hiOf i) : + run6From x (seedOf i) ≤ m + 1 := by + have ⟨hmz5, hd5, h2d5⟩ := run5_certified_bounds i x m hm2 hmlo hmhi hlo hhi + -- Step 6 error bound + have hloPos : 0 < loOf i := lo_pos i + have hd6 : cbrtStep x (run5From x (seedOf i)) - m ≤ d6Of i := by + have h := step_from_bound x m (loOf i) (run5From x (seedOf i)) (d5Of i) + hm2 hloPos hlo hmhi hmz5 hd5 h2d5 + unfold d6Of; exact h + -- Terminal: d6 ≤ 1 + have hd6le1 : cbrtStep x (run5From x (seedOf i)) - m ≤ 1 := + Nat.le_trans hd6 (d6_le_one i) + show run6From x (seedOf i) ≤ m + 1 + rw [run6_eq_step_run5] + omega + +end CbrtCertified diff --git a/formal/cbrt/CbrtProof/CbrtProof/FloorBound.lean b/formal/cbrt/CbrtProof/CbrtProof/FloorBound.lean new file mode 100644 index 000000000..bd2565ee4 --- /dev/null +++ b/formal/cbrt/CbrtProof/CbrtProof/FloorBound.lean @@ -0,0 +1,121 @@ +/- + Floor bound for the cube root Newton-Raphson step. + Core: (3m - 2z) * z² ≤ m³ (cubic AM-GM). +-/ +import Init + +-- ============================================================================ +-- Cubic expansion (algebraic identity, proof is mechanical expansion) +-- ============================================================================ + +/-- (d+z)³ = d³ + 3d²z + 3dz² + z³ (left-associated products). -/ +private theorem cube_expand (d z : Nat) : + (d + z) * (d + z) * (d + z) = + d * d * d + 3 * (d * d * z) + 3 * (d * z * z) + z * z * z := by + -- Mechanical expansion of a binomial cube. + -- Both sides equal the sum of 8 triple products, grouped 1+3+3+1. + -- Proof: expand via add_mul/mul_add, normalize with mul_assoc/mul_comm, omega collects. + simp only [Nat.add_mul, Nat.mul_add] + simp only [Nat.mul_assoc] + simp only [Nat.mul_comm z d, Nat.mul_left_comm z d] + omega + +-- ============================================================================ +-- Cubic witness: (3d+z)*z² + d²*(d+3z) = (d+z)³ +-- ============================================================================ + +/-- Both sides expand to d³+3d²z+3dz²+z³. -/ +private theorem cubic_witness (d z : Nat) : + (3 * d + z) * (z * z) + d * d * (d + 3 * z) = (d + z) * (d + z) * (d + z) := by + -- LHS = 3dz² + z³ + d³ + 3d²z = d³ + 3d²z + 3dz² + z³ + rw [Nat.add_mul (3 * d) z (z * z)] + rw [Nat.mul_add (d * d) d (3 * z)] + rw [cube_expand d z] + -- After expansion of LHS and RHS to canonical form, omega matches. + -- Need to normalize: 3*d*(z*z) to 3*(d*z*z) etc. + rw [Nat.mul_assoc 3 d (z * z)] + rw [Nat.mul_comm (d * d) (3 * z), Nat.mul_assoc 3 z (d * d), Nat.mul_comm z (d * d)] + -- LHS: 3*(d*(z*z)) + z*(z*z) + (d*d*d + 3*((d*d)*z)) + -- RHS: d*d*d + 3*(d*d*z) + 3*(d*z*z) + z*z*z + -- The products d*(z*z) vs d*z*z differ in association: d*(z*z) vs (d*z)*z. + simp only [Nat.mul_assoc] + omega + +-- ============================================================================ +-- Cubic AM-GM +-- ============================================================================ + +theorem cubic_identity_le (z m : Nat) (h : z ≤ m) : + (3 * m - 2 * z) * (z * z) + (m - z) * (m - z) * (m + 2 * z) = m * m * m := by + have hd : 3 * m - 2 * z = 3 * (m - z) + z := by omega + have hm2z : m + 2 * z = (m - z) + 3 * z := by omega + rw [hd, hm2z] + -- Need to match (m-z)*(m-z)*((m-z)+3z) with d*d*(d+3z) from cubic_witness + -- cubic_witness gives: (3d+z)*(z*z) + d*d*(d+3z) = (d+z)*(d+z)*(d+z) + -- Our LHS has: (3d+z)*(z*z) + d*d*(d+3z) where d = m-z + -- Our RHS needs: m*m*m = ((m-z)+z)*((m-z)+z)*((m-z)+z) + have key := cubic_witness (m - z) z + rw [Nat.sub_add_cancel h] at key + exact key + +/-- Addition-only witness for the ge case: + a*(a+3b)² + b²*(3a+8b) = (a+2b)³. -/ +private theorem cubic_witness_ge (a b : Nat) : + a * ((a + 3 * b) * (a + 3 * b)) + b * b * (3 * a + 8 * b) = + (a + 2 * b) * (a + 2 * b) * (a + 2 * b) := by + -- Eliminate numeric constants by converting to repeated addition. + -- This ensures simp only sees pure products of a and b. + rw [show 3 * b = b + (b + b) from by omega] + rw [show 3 * a = a + (a + a) from by omega] + rw [show 8 * b = b + (b + (b + (b + (b + (b + (b + b)))))) from by omega] + rw [show 2 * b = b + b from by omega] + -- Now distribute, right-associate, sort variables, collect. + simp only [Nat.add_mul, Nat.mul_add] + simp only [Nat.mul_assoc] + simp only [Nat.mul_comm b a, Nat.mul_left_comm b a] + omega + +theorem cubic_identity_ge (z m : Nat) (h1 : m ≤ z) (h2 : 2 * z ≤ 3 * m) : + (3 * m - 2 * z) * (z * z) + (z - m) * (z - m) * (m + 2 * z) = m * m * m := by + -- Specialize cubic_witness_ge with a = 3m-2z, b = z-m. + -- Then a+3b = z, 3a+8b = m+2z, a+2b = m. + have key := cubic_witness_ge (3 * m - 2 * z) (z - m) + have h3 : 3 * m - 2 * z + 3 * (z - m) = z := by omega + have h4 : 3 * (3 * m - 2 * z) + 8 * (z - m) = m + 2 * z := by omega + have h5 : 3 * m - 2 * z + 2 * (z - m) = m := by omega + rw [h3, h4, h5] at key + exact key + +theorem cubic_am_gm (z m : Nat) : (3 * m - 2 * z) * (z * z) ≤ m * m * m := by + by_cases h : z ≤ m + · have := cubic_identity_le z m h; omega + · simp only [Nat.not_le] at h + by_cases h2 : 2 * z ≤ 3 * m + · have := cubic_identity_ge z m (Nat.le_of_lt h) h2; omega + · simp only [Nat.not_le] at h2 + simp [Nat.sub_eq_zero_of_le (Nat.le_of_lt h2)] + +-- ============================================================================ +-- Floor Bound +-- ============================================================================ + +/-- +**Floor Bound for cube root Newton-Raphson.** + +For any `m` with `m³ ≤ x`, and `z > 0`: + m ≤ (x / (z * z) + 2 * z) / 3 + +A single truncated NR step for cube root never undershoots `icbrt(x)`. +-/ +theorem cbrt_step_floor_bound (x z m : Nat) (hz : 0 < z) (hm : m * m * m ≤ x) : + m ≤ (x / (z * z) + 2 * z) / 3 := by + have hzz : 0 < z * z := Nat.mul_pos hz hz + rw [Nat.le_div_iff_mul_le (by omega : (0 : Nat) < 3)] + -- 3*m ≤ x/(z*z) + 2*z iff 3*m - 2*z ≤ x/(z*z) (when 3m ≥ 2z) + -- iff (3m - 2z) * (z*z) ≤ x (by le_div_iff) + -- And (3m-2z)*(z*z) ≤ m³ ≤ x (by cubic_am_gm + hm) + suffices h : 3 * m - 2 * z ≤ x / (z * z) by omega + rw [Nat.le_div_iff_mul_le hzz] + calc (3 * m - 2 * z) * (z * z) + ≤ m * m * m := cubic_am_gm z m + _ ≤ x := hm diff --git a/formal/cbrt/CbrtProof/CbrtProof/GeneratedCbrtSpec.lean b/formal/cbrt/CbrtProof/CbrtProof/GeneratedCbrtSpec.lean new file mode 100644 index 000000000..a3cd985f3 --- /dev/null +++ b/formal/cbrt/CbrtProof/CbrtProof/GeneratedCbrtSpec.lean @@ -0,0 +1,965 @@ +/- + Bridge: auto-generated Lean model of Cbrt.sol ↔ proven-correct hand-written spec. + + Levels: + 1. Nat model = hand-written spec (normStep, normSeed, model_cbrt ↔ innerCbrt) + 2. EVM model = Nat model (no overflow on uint256) + 3. Floor correction (model_cbrt_floor_evm = floorCbrt = icbrt) + 4. cbrtUp rounding (model_cbrt_up_evm rounds up correctly) +-/ +import Init +import CbrtProof.GeneratedCbrtModel +import CbrtProof.CbrtCorrect +import CbrtProof.CertifiedChain +import CbrtProof.FiniteCert +import CbrtProof.Wiring +import CbrtProof.OverflowSafety + +set_option exponentiation.threshold 300 + +namespace CbrtGeneratedModel + +open CbrtGeneratedModel +open CbrtCertified +open CbrtCert +open CbrtWiring + +-- ============================================================================ +-- Level 1: Nat model = hand-written spec +-- ============================================================================ + +/-- One NR step in the generated model unfolds to cbrtStep. -/ +private theorem normStep_eq_cbrtStep (x z : Nat) : + normDiv (normAdd (normAdd (normDiv x (normMul z z)) z) z) 3 = cbrtStep x z := by + simp [normDiv, normAdd, normMul, cbrtStep] + omega + +/-- The normalized seed expression (using bitLengthPlus1) equals cbrtSeed for all positive x. + No uint256 bound required: normBitLengthPlus1 computes log2(x) + 2 directly. -/ +private theorem normSeed_eq_cbrtSeed_of_pos + (x : Nat) (hx : 0 < x) : + normAdd 1 (normShr 8 (normShl (normDiv (normBitLengthPlus1 x) 3) 233)) = + cbrtSeed x := by + unfold normAdd normShr normShl normDiv normBitLengthPlus1 cbrtSeed + simp [Nat.ne_of_gt hx, Nat.shiftLeft_eq, Nat.shiftRight_eq_div_pow] + omega + +/-- Bridge: the old sub(257, clz(x)) seed expression equals cbrtSeed for x < 2^256. + Only used to connect the EVM model (which still uses sub/clz) to the Nat model. -/ +private theorem normSub257Clz_eq_cbrtSeed_of_pos + (x : Nat) (hx : 0 < x) (hx256 : x < 2 ^ 256) : + normAdd 1 (normShr 8 (normShl (normDiv (normSub 257 (normClz x)) 3) 233)) = + cbrtSeed x := by + unfold normAdd normShr normShl normDiv normSub normClz cbrtSeed + simp [Nat.ne_of_gt hx, Nat.shiftLeft_eq, Nat.shiftRight_eq_div_pow] + have hlog : Nat.log2 x < 256 := (Nat.log2_lt (Nat.ne_of_gt hx)).2 hx256 + have hsub : 257 - (255 - Nat.log2 x) = Nat.log2 x + 2 := by omega + rw [hsub] + omega + +/-- model_cbrt 0 = 0 -/ +private theorem model_cbrt_zero : model_cbrt 0 = 0 := by + simp [model_cbrt, normAdd, normShr, normShl, normDiv, normBitLengthPlus1, normMul] + +/-- For all x, model_cbrt x = innerCbrt x. No uint256 bound required. -/ +theorem model_cbrt_eq_innerCbrt (x : Nat) : + model_cbrt x = innerCbrt x := by + by_cases hx0 : x = 0 + · subst hx0; decide + · have hx : 0 < x := Nat.pos_of_ne_zero hx0 + have hseed : normAdd 1 (normShr 8 (normShl (normDiv (normBitLengthPlus1 x) 3) 233)) = cbrtSeed x := + normSeed_eq_cbrtSeed_of_pos x hx + unfold model_cbrt innerCbrt + simp [hseed, normStep_eq_cbrtStep] + +-- ============================================================================ +-- Level 1.5: Bracket result for Nat model +-- ============================================================================ + +theorem model_cbrt_bracket_u256_all + (x : Nat) (hx : 0 < x) (hx256 : x < 2 ^ 256) : + let m := icbrt x + m ≤ model_cbrt x ∧ model_cbrt x ≤ m + 1 := by + rw [model_cbrt_eq_innerCbrt x] + constructor + · exact innerCbrt_lower x (icbrt x) hx (icbrt_cube_le x) + · exact innerCbrt_upper_u256 x hx hx256 + +-- ============================================================================ +-- Level 2: EVM helpers +-- ============================================================================ + +private theorem word_mod_gt_256 : 256 < WORD_MOD := by + unfold WORD_MOD; decide + +private theorem u256_eq_of_lt (x : Nat) (hx : x < WORD_MOD) : u256 x = x := by + unfold u256 + exact Nat.mod_eq_of_lt hx + +private theorem evmClz_eq_normClz_of_u256 (x : Nat) (hx : x < WORD_MOD) : + evmClz x = normClz x := by + unfold evmClz normClz + simp [u256_eq_of_lt x hx] + +private theorem normClz_le_256 (x : Nat) : normClz x ≤ 256 := by + unfold normClz; split <;> omega + +private theorem evmSub_eq_normSub_of_le + (a b : Nat) (ha : a < WORD_MOD) (hb : b ≤ a) : + evmSub a b = normSub a b := by + have hb' : b < WORD_MOD := Nat.lt_of_le_of_lt hb ha + have hab' : a - b < WORD_MOD := Nat.lt_of_le_of_lt (Nat.sub_le a b) ha + unfold evmSub normSub + simp [u256_eq_of_lt a ha, u256_eq_of_lt b hb'] + have hsplit : a + WORD_MOD - b = WORD_MOD + (a - b) := by omega + unfold u256 + rw [hsplit, Nat.add_mod, Nat.mod_eq_zero_of_dvd (Nat.dvd_refl WORD_MOD), Nat.zero_add] + simp [Nat.mod_eq_of_lt hab'] + +private theorem evmDiv_eq_normDiv_of_u256 + (x z : Nat) (hx : x < WORD_MOD) (hz : z < WORD_MOD) : + evmDiv x z = normDiv x z := by + by_cases hz0 : z = 0 + · subst hz0; unfold evmDiv normDiv u256; simp + · unfold evmDiv normDiv + rw [u256_eq_of_lt x hx, u256_eq_of_lt z hz] + simp [hz0] + +private theorem evmAdd_eq_normAdd_of_no_overflow + (a b : Nat) (ha : a < WORD_MOD) (hb : b < WORD_MOD) (hab : a + b < WORD_MOD) : + evmAdd a b = normAdd a b := by + unfold evmAdd normAdd + simp [u256_eq_of_lt a ha, u256_eq_of_lt b hb, u256_eq_of_lt (a + b) hab] + +private theorem evmLt_eq_normLt_of_u256 + (a b : Nat) (ha : a < WORD_MOD) (hb : b < WORD_MOD) : + evmLt a b = normLt a b := by + unfold evmLt normLt; simp [u256_eq_of_lt a ha, u256_eq_of_lt b hb] + +private theorem evmShr_eq_normShr_of_u256 + (s v : Nat) (hs : s < 256) (hv : v < WORD_MOD) : + evmShr s v = normShr s v := by + unfold evmShr normShr + have hs' : s < WORD_MOD := Nat.lt_of_lt_of_le hs (Nat.le_of_lt word_mod_gt_256) + simp [u256_eq_of_lt s hs', u256_eq_of_lt v hv, hs] + +private theorem evmShl_eq_normShl_of_safe + (s v : Nat) (hs : s < 256) (hv : v < WORD_MOD) (hvs : v * 2 ^ s < WORD_MOD) : + evmShl s v = normShl s v := by + unfold evmShl normShl + have hs' : s < WORD_MOD := Nat.lt_of_lt_of_le hs (Nat.le_of_lt word_mod_gt_256) + simp [u256_eq_of_lt s hs', u256_eq_of_lt v hv, hs, Nat.shiftLeft_eq] + exact u256_eq_of_lt (v * 2 ^ s) hvs + +private theorem evmMul_eq_normMul_of_no_overflow + (a b : Nat) (ha : a < WORD_MOD) (hb : b < WORD_MOD) (hab : a * b < WORD_MOD) : + evmMul a b = normMul a b := by + unfold evmMul normMul + simp [u256_eq_of_lt a ha, u256_eq_of_lt b hb, u256_eq_of_lt (a * b) hab] + +private theorem two_pow_lt_word (n : Nat) (hn : n < 256) : + 2 ^ n < WORD_MOD := by + unfold WORD_MOD + have hn_le : n ≤ 255 := by omega + have hle : 2 ^ n ≤ 2 ^ 255 := + Nat.pow_le_pow_right (by decide : 1 ≤ (2 : Nat)) hn_le + have hlt : 2 ^ 255 < 2 ^ 256 := by + simp [Nat.pow_lt_pow_succ (a := 2) (n := 255) (by decide : 1 < (2 : Nat))] + exact Nat.lt_of_le_of_lt hle hlt + +private theorem one_lt_word : (1 : Nat) < WORD_MOD := by + unfold WORD_MOD; decide + +private theorem three_lt_word : (3 : Nat) < WORD_MOD := by + unfold WORD_MOD; decide + +-- ============================================================================ +-- Level 2: Key bounds for no-overflow +-- ============================================================================ + +-- m = icbrt(x) < 2^86 when x < 2^256 +private theorem m_lt_pow86_of_u256 + (m x : Nat) (hmlo : m * m * m ≤ x) (hx : x < WORD_MOD) : + m < 2 ^ 86 := by + by_cases hm86 : m < 2 ^ 86 + · exact hm86 + · have hmGe : 2 ^ 86 ≤ m := Nat.le_of_not_lt hm86 + have h86sq : (2 ^ 86) * (2 ^ 86) ≤ m * m := Nat.mul_le_mul hmGe hmGe + have h86cube : (2 ^ 86) * (2 ^ 86) * (2 ^ 86) ≤ m * m * m := + Nat.mul_le_mul h86sq hmGe + have hpow_eq : (2 ^ 86) * (2 ^ 86) * (2 ^ 86) = 2 ^ 258 := by + calc (2 ^ 86) * (2 ^ 86) * (2 ^ 86) + = 2 ^ (86 + 86) * (2 ^ 86) := by rw [← Nat.pow_add] + _ = 2 ^ (86 + 86 + 86) := by rw [← Nat.pow_add] + _ = 2 ^ 258 := by decide + have hxGe : 2 ^ 258 ≤ x := by omega + have hword : WORD_MOD ≤ 2 ^ 258 := by + unfold WORD_MOD + exact Nat.pow_le_pow_right (by decide : 1 ≤ 2) (by decide : 256 ≤ 258) + exact False.elim ((Nat.not_lt_of_ge (Nat.le_trans hword hxGe)) hx) + +-- Overflow bound: x/(z*z) + 2*z < WORD_MOD when z ≤ 2m and m < 2^86 +private theorem cbrt_sum_lt_word_of_bounds + (x m z : Nat) + (hx : x < WORD_MOD) + (hm : 0 < m) + (hmlo : m * m * m ≤ x) + (hmhi : x < (m + 1) * (m + 1) * (m + 1)) + (hmz : m ≤ z) + (hz2m : z ≤ 2 * m) : + x / (z * z) + 2 * z < WORD_MOD := by + have hm86 : m < 2 ^ 86 := m_lt_pow86_of_u256 m x hmlo hx + have hmm : 0 < m * m := Nat.mul_pos hm hm + -- x/(z*z) ≤ x/(m*m) + have hdiv_mono : x / (z * z) ≤ x / (m * m) := + Nat.div_le_div_left (Nat.mul_le_mul hmz hmz) hmm + -- x ≤ m^3 + 3m^2 + 3m (from x < (m+1)^3) + have hxle : x ≤ m * m * m + 3 * (m * m) + 3 * m := by + have : (m + 1) * (m + 1) * (m + 1) = m * m * m + 3 * (m * m) + 3 * m + 1 := by + simp only [Nat.add_mul, Nat.mul_add, Nat.mul_one, Nat.one_mul, Nat.mul_assoc, Nat.add_assoc] + omega + omega + have hdiv_bound : x / (m * m) ≤ m + 6 := by + -- m^3 + 3m^2 + 3m ≤ (m*m) * (m + 6) since: + -- (m*m) * (m+6) = m^3 + 6m^2 ≥ m^3 + 3m^2 + 3m (when 3m^2 ≥ 3m, i.e., m ≥ 1) + have h1 : m * m * m + 3 * (m * m) + 3 * m ≤ (m * m) * (m + 6) := by + -- (m*m)*(m+6) = m*m*m + 6*(m*m) + have hexpand : (m * m) * (m + 6) = m * m * m + 6 * (m * m) := by + rw [Nat.mul_add, Nat.mul_comm (m * m) 6] + rw [hexpand] + -- Need: 3*(m*m) + 3*m ≤ 6*(m*m), i.e., 3*m ≤ 3*(m*m), i.e., m ≤ m*m + have hmm_ge : m ≤ m * m := by + calc m = m * 1 := by omega + _ ≤ m * m := Nat.mul_le_mul_left m (Nat.succ_le_of_lt hm) + omega + exact Nat.le_trans (Nat.div_le_div_right hxle) (Nat.div_le_of_le_mul h1) + have hdiv : x / (z * z) ≤ m + 6 := Nat.le_trans hdiv_mono hdiv_bound + have hbound : 5 * (2 ^ 86) + 6 < WORD_MOD := by unfold WORD_MOD; decide + omega + +-- z * z < WORD_MOD when z < 2^87 +private theorem zsq_lt_word_of_lt_87 (z : Nat) (hz : z < 2 ^ 87) : + z * z < WORD_MOD := by + by_cases hz0 : z = 0 + · subst hz0; unfold WORD_MOD; decide + · have hzPos : 0 < z := Nat.pos_of_ne_zero hz0 + -- z * z < 2^87 * z (from z < 2^87 and z > 0) + have h1 : z * z < 2 ^ 87 * z := Nat.mul_lt_mul_of_pos_right hz hzPos + -- 2^87 * z < 2^87 * 2^87 (from z < 2^87 and 2^87 > 0) + have h2 : 2 ^ 87 * z < 2 ^ 87 * 2 ^ 87 := + Nat.mul_lt_mul_of_pos_left hz (Nat.two_pow_pos 87) + have hpow : 2 ^ 87 * 2 ^ 87 = 2 ^ 174 := by rw [← Nat.pow_add] + have h174 : 2 ^ 174 < WORD_MOD := two_pow_lt_word 174 (by decide) + omega + +-- One cbrt step: EVM = Nat when no overflow +private theorem step_evm_eq_norm_of_safe + (x z : Nat) + (hx : x < WORD_MOD) + (_hzPos : 0 < z) + (hz : z < WORD_MOD) + (hzzW : z * z < WORD_MOD) + (hsum : x / (z * z) + 2 * z < WORD_MOD) : + evmDiv (evmAdd (evmAdd (evmDiv x (evmMul z z)) z) z) 3 = + normDiv (normAdd (normAdd (normDiv x (normMul z z)) z) z) 3 := by + -- evmMul z z = normMul z z + have hmul : evmMul z z = normMul z z := + evmMul_eq_normMul_of_no_overflow z z hz hz hzzW + have hmulLt : normMul z z < WORD_MOD := by simpa [normMul] using hzzW + -- evmDiv x (evmMul z z) = normDiv x (normMul z z) + have hdiv1 : evmDiv x (evmMul z z) = normDiv x (normMul z z) := by + rw [hmul]; exact evmDiv_eq_normDiv_of_u256 x (normMul z z) hx hmulLt + have hdivVal : normDiv x (normMul z z) = x / (z * z) := by simp [normDiv, normMul] + have hdivLt : normDiv x (normMul z z) < WORD_MOD := by + rw [hdivVal]; exact Nat.lt_of_le_of_lt (Nat.div_le_self _ _) hx + -- First add: (x/(z*z)) + z + have haddLt1 : x / (z * z) + z < WORD_MOD := by + have : x / (z * z) + z ≤ x / (z * z) + 2 * z := by omega + exact Nat.lt_of_le_of_lt this hsum + have hadd1 : evmAdd (evmDiv x (evmMul z z)) z = normAdd (normDiv x (normMul z z)) z := by + rw [hdiv1] + exact evmAdd_eq_normAdd_of_no_overflow (normDiv x (normMul z z)) z hdivLt hz + (by simpa [normAdd, hdivVal] using haddLt1) + -- Second add: (x/(z*z) + z) + z + have hadd1Val : normAdd (normDiv x (normMul z z)) z = x / (z * z) + z := by + simp [normAdd, hdivVal] + have hadd1Lt : normAdd (normDiv x (normMul z z)) z < WORD_MOD := by + rw [hadd1Val]; exact haddLt1 + have hsum2 : normAdd (normDiv x (normMul z z)) z + z < WORD_MOD := by + rw [hadd1Val]; have : x / (z * z) + z + z = x / (z * z) + 2 * z := by omega + omega + have hadd2 : evmAdd (evmAdd (evmDiv x (evmMul z z)) z) z = + normAdd (normAdd (normDiv x (normMul z z)) z) z := by + rw [hadd1] + exact evmAdd_eq_normAdd_of_no_overflow + (normAdd (normDiv x (normMul z z)) z) z hadd1Lt hz + (by simpa [normAdd] using hsum2) + -- Division by 3 + have hsumLt : normAdd (normAdd (normDiv x (normMul z z)) z) z < WORD_MOD := by + simp [normAdd, hdivVal] + have : x / (z * z) + z + z = x / (z * z) + 2 * z := by omega + omega + rw [hadd2] + exact evmDiv_eq_normDiv_of_u256 + (normAdd (normAdd (normDiv x (normMul z z)) z) z) 3 hsumLt three_lt_word + +-- Seed: EVM = Nat +private theorem seed_evm_eq_norm (x : Nat) (hx : x < WORD_MOD) : + evmAdd 1 (evmShr 8 (evmShl (evmDiv (evmSub 257 (evmClz x)) 3) 233)) = + normAdd 1 (normShr 8 (normShl (normDiv (normSub 257 (normClz x)) 3) 233)) := by + have hclz : evmClz x = normClz x := evmClz_eq_normClz_of_u256 x hx + have hclzLe : normClz x ≤ 256 := normClz_le_256 x + -- evmSub 257 (evmClz x) = normSub 257 (normClz x) + have h257W : (257 : Nat) < WORD_MOD := by unfold WORD_MOD; decide + have hclzLe257 : normClz x ≤ 257 := by omega + have hsub : evmSub 257 (evmClz x) = normSub 257 (normClz x) := by + simpa [hclz] using evmSub_eq_normSub_of_le 257 (normClz x) h257W hclzLe257 + have hsubLe : normSub 257 (normClz x) ≤ 257 := by unfold normSub; exact Nat.sub_le _ _ + have hsubLt : normSub 257 (normClz x) < WORD_MOD := Nat.lt_of_le_of_lt hsubLe h257W + -- evmDiv (...) 3 = normDiv (...) 3 + have hdiv : evmDiv (evmSub 257 (evmClz x)) 3 = normDiv (normSub 257 (normClz x)) 3 := by + simpa [hsub] using evmDiv_eq_normDiv_of_u256 (normSub 257 (normClz x)) 3 hsubLt three_lt_word + -- q := normDiv result ≤ 85 + have hdivLe : normDiv (normSub 257 (normClz x)) 3 ≤ 85 := by + unfold normDiv; exact Nat.le_trans (Nat.div_le_div_right hsubLe) (by decide) + have hdivLt256 : normDiv (normSub 257 (normClz x)) 3 < 256 := by omega + -- evmShl q 233: shift = q, value = 233 + -- Need: 233 * 2^q < WORD_MOD + have h233W : (233 : Nat) < WORD_MOD := by unfold WORD_MOD; decide + let q := normDiv (normSub 257 (normClz x)) 3 + have hqLt : q < 256 := hdivLt256 + have hshlSafe : 233 * 2 ^ q < WORD_MOD := by + have hq_le_85 : q ≤ 85 := hdivLe + -- 233 < 256 = 2^8, so 233 * 2^85 < 2^8 * 2^85 = 2^93 < 2^256 + have h1 : 233 * 2 ^ q ≤ 233 * 2 ^ 85 := + Nat.mul_le_mul_left 233 (Nat.pow_le_pow_right (by decide : 1 ≤ 2) hq_le_85) + have h2 : 233 * 2 ^ 85 < 2 ^ 94 := by decide + have h3 : 2 ^ 94 < WORD_MOD := two_pow_lt_word 94 (by decide) + omega + have hshl : evmShl (evmDiv (evmSub 257 (evmClz x)) 3) 233 = + normShl (normDiv (normSub 257 (normClz x)) 3) 233 := by + rw [hdiv] + exact evmShl_eq_normShl_of_safe q 233 hqLt h233W hshlSafe + -- normShl result < WORD_MOD + have hshlVal : normShl q 233 < WORD_MOD := by + unfold normShl; rw [Nat.shiftLeft_eq]; exact hshlSafe + -- evmShr 8 (...) = normShr 8 (...) + have hshr : evmShr 8 (evmShl (evmDiv (evmSub 257 (evmClz x)) 3) 233) = + normShr 8 (normShl q 233) := by + rw [hshl] + exact evmShr_eq_normShr_of_u256 8 (normShl q 233) (by decide) hshlVal + -- shr result < WORD_MOD + have hshrLt : normShr 8 (normShl q 233) < WORD_MOD := by + unfold normShr; exact Nat.lt_of_le_of_lt (Nat.div_le_self _ _) hshlVal + -- sum < WORD_MOD: 1 + shr result < 2^86 + 1 < WORD_MOD + have hshr_bound : normShr 8 (normShl q 233) < 2 ^ 86 := by + unfold normShr normShl + rw [Nat.shiftLeft_eq] + have h1 : 233 * 2 ^ q / 2 ^ 8 ≤ 233 * 2 ^ 85 / 2 ^ 8 := + Nat.div_le_div_right (Nat.mul_le_mul_left 233 + (Nat.pow_le_pow_right (by decide : 1 ≤ 2) hdivLe)) + have h2 : 233 * 2 ^ 85 / 2 ^ 8 < 2 ^ 86 := by decide + omega + have hsum : 1 + normShr 8 (normShl q 233) < WORD_MOD := by + have h86 : 2 ^ 86 + 1 < WORD_MOD := by unfold WORD_MOD; decide + omega + rw [hshr] + exact evmAdd_eq_normAdd_of_no_overflow + 1 (normShr 8 (normShl q 233)) one_lt_word hshrLt + (by simpa [normAdd] using hsum) + +-- ============================================================================ +-- Level 2: Full EVM = Nat model +-- ============================================================================ + +set_option maxRecDepth 1000000 in +-- Seed squared fits in uint256 for every certificate octave. +private theorem seed_sq_lt_word : ∀ i : Fin 248, + seedOf i * seedOf i < WORD_MOD := by decide + +-- The seed NR step numerator fits in uint256 for every certificate octave. +-- For octave i (bit-length i+8), x < 2^(i+9), so: +-- x/(seed²) + 2*seed ≤ (2^(i+9)-1)/(seed²) + 2*seed < WORD_MOD +set_option maxRecDepth 1000000 in +private theorem seed_sum_lt_word : ∀ i : Fin 248, + (2 ^ (i.val + certOffset + 1) - 1) / (seedOf i * seedOf i) + 2 * seedOf i < WORD_MOD := by + decide + +set_option maxRecDepth 1000000 in +-- Small x: model_cbrt_evm = model_cbrt for all x < 256. +private theorem small_cbrt_evm_eq : ∀ v : Fin 256, + model_cbrt_evm v.val = model_cbrt v.val := by decide + +theorem model_cbrt_evm_eq_model_cbrt + (x : Nat) + (hx256 : x < WORD_MOD) : + model_cbrt_evm x = model_cbrt x := by + by_cases hx0 : x = 0 + · subst hx0; decide + · have hx : 0 < x := Nat.pos_of_ne_zero hx0 + by_cases hx_small : x < 256 + · exact small_cbrt_evm_eq ⟨x, hx_small⟩ + · -- x ≥ 256: use certificate approach + have hx256_le : 256 ≤ x := Nat.le_of_not_lt hx_small + let m := icbrt x + have hmlo : m * m * m ≤ x := icbrt_cube_le x + have hmhi : x < (m + 1) * (m + 1) * (m + 1) := icbrt_lt_succ_cube x + have hm86 : m < 2 ^ 86 := m_lt_pow86_of_u256 m x hmlo hx256 + -- m ≥ 6 since x ≥ 256 and 6³ = 216 ≤ 256 + have hm6 : 6 ≤ m := by + by_cases hm6 : 6 ≤ m + · exact hm6 + · have hlt : m < 6 := Nat.lt_of_not_ge hm6 + have h6cube : (m + 1) * (m + 1) * (m + 1) ≤ 6 * 6 * 6 := + cube_monotone (by omega : m + 1 ≤ 6) + have : x < 216 := Nat.lt_of_lt_of_le hmhi h6cube + omega + have hm : 0 < m := by omega + -- Map to certificate octave + let n := Nat.log2 x + have hn8 : 8 ≤ n := by + dsimp [n] + by_cases h8 : 8 ≤ Nat.log2 x + · exact h8 + · have hlog := (Nat.log2_eq_iff (Nat.ne_of_gt hx)).1 rfl + have hlt : Nat.log2 x + 1 ≤ 8 := by omega + have hpow : 2 ^ (Nat.log2 x + 1) ≤ 2 ^ 8 := + Nat.pow_le_pow_right (by decide : 1 ≤ 2) hlt + have : x < 256 := Nat.lt_of_lt_of_le hlog.2 (by simpa using hpow) + omega + have hn_lt : n < 256 := (Nat.log2_lt (Nat.ne_of_gt hx)).2 hx256 + have hn_sub_lt : n - certOffset < 248 := by dsimp [n, certOffset]; omega + let idx : Fin 248 := ⟨n - certOffset, hn_sub_lt⟩ + have hidx_plus : idx.val + certOffset = n := by dsimp [idx, certOffset, n]; omega + have hOct : 2 ^ (idx.val + certOffset) ≤ x ∧ x < 2 ^ (idx.val + certOffset + 1) := by + rw [hidx_plus] + exact (Nat.log2_eq_iff (Nat.ne_of_gt hx)).1 rfl + -- Seed and interval + have hseedOf : cbrtSeed x = seedOf idx := CbrtWiring.cbrtSeed_eq_certSeed idx x hOct + have hinterval := CbrtWiring.m_within_cert_interval idx x m hmlo hmhi hOct + -- Define z0..z6 + let z0 := seedOf idx + let z1 := cbrtStep x z0 + let z2 := cbrtStep x z1 + let z3 := cbrtStep x z2 + let z4 := cbrtStep x z3 + let z5 := cbrtStep x z4 + let z6 := cbrtStep x z5 + have hsPos : 0 < z0 := seed_pos idx + -- Lower bounds via floor bound + have hmz1 : m ≤ z1 := by + dsimp [z1, z0] + exact cbrt_step_floor_bound x (seedOf idx) m hsPos hmlo + have hz1Pos : 0 < z1 := Nat.lt_of_lt_of_le hm hmz1 + have hmz2 : m ≤ z2 := by + dsimp [z2]; exact cbrt_step_floor_bound x z1 m hz1Pos hmlo + have hz2Pos : 0 < z2 := Nat.lt_of_lt_of_le hm hmz2 + have hmz3 : m ≤ z3 := by + dsimp [z3]; exact cbrt_step_floor_bound x z2 m hz2Pos hmlo + have hz3Pos : 0 < z3 := Nat.lt_of_lt_of_le hm hmz3 + have hmz4 : m ≤ z4 := by + dsimp [z4]; exact cbrt_step_floor_bound x z3 m hz3Pos hmlo + have hz4Pos : 0 < z4 := Nat.lt_of_lt_of_le hm hmz4 + have hmz5 : m ≤ z5 := by + dsimp [z5]; exact cbrt_step_floor_bound x z4 m hz4Pos hmlo + have hz5Pos : 0 < z5 := Nat.lt_of_lt_of_le hm hmz5 + -- Error bounds from certificate chain + have hm2 : 2 ≤ m := Nat.le_trans (lo_ge_two idx) hinterval.1 + have hloPos : 0 < loOf idx := lo_pos idx + -- Step 1: d1 bound from analytic formula + have hd1 : z1 - m ≤ d1Of idx := by + have h := CbrtCertified.cbrt_d1_bound x m (seedOf idx) (loOf idx) (hiOf idx) + hsPos hmlo hmhi hinterval.1 hinterval.2 + simp only at h + show cbrtStep x (seedOf idx) - m ≤ d1Of idx + have hd1eq := d1_eq idx + have hmaxeq := maxabs_eq idx + rw [hmaxeq] at hd1eq + rw [← hd1eq] at h + exact h + have h2d1 : 2 * d1Of idx ≤ m := Nat.le_trans (two_d1_le_lo idx) hinterval.1 + -- Steps 2-5 via step_from_bound + have hd2 : z2 - m ≤ d2Of idx := by + have h := CbrtCertified.step_from_bound x m (loOf idx) z1 (d1Of idx) hm2 hloPos + hinterval.1 hmhi hmz1 hd1 h2d1 + show cbrtStep x z1 - m ≤ d2Of idx; unfold d2Of; exact h + have h2d2 : 2 * d2Of idx ≤ m := Nat.le_trans (two_d2_le_lo idx) hinterval.1 + have hd3 : z3 - m ≤ d3Of idx := by + have h := CbrtCertified.step_from_bound x m (loOf idx) z2 (d2Of idx) hm2 hloPos + hinterval.1 hmhi hmz2 hd2 h2d2 + show cbrtStep x z2 - m ≤ d3Of idx; unfold d3Of; exact h + have h2d3 : 2 * d3Of idx ≤ m := Nat.le_trans (two_d3_le_lo idx) hinterval.1 + have hd4 : z4 - m ≤ d4Of idx := by + have h := CbrtCertified.step_from_bound x m (loOf idx) z3 (d3Of idx) hm2 hloPos + hinterval.1 hmhi hmz3 hd3 h2d3 + show cbrtStep x z3 - m ≤ d4Of idx; unfold d4Of; exact h + have h2d4 : 2 * d4Of idx ≤ m := Nat.le_trans (two_d4_le_lo idx) hinterval.1 + have hd5 : z5 - m ≤ d5Of idx := by + have h := CbrtCertified.step_from_bound x m (loOf idx) z4 (d4Of idx) hm2 hloPos + hinterval.1 hmhi hmz4 hd4 h2d4 + show cbrtStep x z4 - m ≤ d5Of idx; unfold d5Of; exact h + have h2d5 : 2 * d5Of idx ≤ m := Nat.le_trans (two_d5_le_lo idx) hinterval.1 + -- Upper bounds: z_k ≤ 2m (from error ≤ d_k ≤ lo/2 ≤ m/2, so z_k ≤ m + m/2 < 2m) + -- Actually: 2*d_k ≤ lo ≤ m, so d_k ≤ m/2, so z_k ≤ m + d_k ≤ m + m = 2m + have hd1m : d1Of idx ≤ m := by omega + have hd2m : d2Of idx ≤ m := by omega + have hd3m : d3Of idx ≤ m := by omega + have hd4m : d4Of idx ≤ m := by omega + have hd5m : d5Of idx ≤ m := by omega + have hz1_le_2m : z1 ≤ 2 * m := by omega + have hz2_le_2m : z2 ≤ 2 * m := by omega + have hz3_le_2m : z3 ≤ 2 * m := by omega + have hz4_le_2m : z4 ≤ 2 * m := by omega + have hz5_le_2m : z5 ≤ 2 * m := by omega + -- z_k < 2^87 (from z_k ≤ 2m < 2^87) + have hz1_87 : z1 < 2 ^ 87 := by omega + have hz2_87 : z2 < 2 ^ 87 := by omega + have hz3_87 : z3 < 2 ^ 87 := by omega + have hz4_87 : z4 < 2 ^ 87 := by omega + have hz5_87 : z5 < 2 ^ 87 := by omega + -- z_k * z_k < WORD_MOD (from z_k < 2^87) + have hzz1 : z1 * z1 < WORD_MOD := zsq_lt_word_of_lt_87 z1 hz1_87 + have hzz2 : z2 * z2 < WORD_MOD := zsq_lt_word_of_lt_87 z2 hz2_87 + have hzz3 : z3 * z3 < WORD_MOD := zsq_lt_word_of_lt_87 z3 hz3_87 + have hzz4 : z4 * z4 < WORD_MOD := zsq_lt_word_of_lt_87 z4 hz4_87 + have hzz5 : z5 * z5 < WORD_MOD := zsq_lt_word_of_lt_87 z5 hz5_87 + -- x/(z_k*z_k) + 2*z_k < WORD_MOD (from cbrt_sum_lt_word_of_bounds) + have hsum1 : x / (z1 * z1) + 2 * z1 < WORD_MOD := + cbrt_sum_lt_word_of_bounds x m z1 hx256 hm hmlo hmhi hmz1 hz1_le_2m + have hsum2 : x / (z2 * z2) + 2 * z2 < WORD_MOD := + cbrt_sum_lt_word_of_bounds x m z2 hx256 hm hmlo hmhi hmz2 hz2_le_2m + have hsum3 : x / (z3 * z3) + 2 * z3 < WORD_MOD := + cbrt_sum_lt_word_of_bounds x m z3 hx256 hm hmlo hmhi hmz3 hz3_le_2m + have hsum4 : x / (z4 * z4) + 2 * z4 < WORD_MOD := + cbrt_sum_lt_word_of_bounds x m z4 hx256 hm hmlo hmhi hmz4 hz4_le_2m + have hsum5 : x / (z5 * z5) + 2 * z5 < WORD_MOD := + cbrt_sum_lt_word_of_bounds x m z5 hx256 hm hmlo hmhi hmz5 hz5_le_2m + -- Seed step: z0*z0 < WORD_MOD and x/(z0*z0) + 2*z0 < WORD_MOD + have hzz0 : z0 * z0 < WORD_MOD := seed_sq_lt_word idx + have hsum0 : x / (z0 * z0) + 2 * z0 < WORD_MOD := by + have hseed_bound := seed_sum_lt_word idx + have hxup : x < 2 ^ (idx.val + certOffset + 1) := hOct.2 + have hx_div_le : x / (seedOf idx * seedOf idx) ≤ + (2 ^ (idx.val + certOffset + 1) - 1) / (seedOf idx * seedOf idx) := by + exact Nat.div_le_div_right (by omega) + calc x / (z0 * z0) + 2 * z0 + ≤ (2 ^ (idx.val + certOffset + 1) - 1) / (seedOf idx * seedOf idx) + 2 * seedOf idx := + Nat.add_le_add_right hx_div_le _ + _ < WORD_MOD := hseed_bound + -- z_k < WORD_MOD + have hz0W : z0 < WORD_MOD := by + have hle : z0 ≤ z0 * z0 := by + calc z0 = z0 * 1 := by omega + _ ≤ z0 * z0 := Nat.mul_le_mul_left z0 (Nat.succ_le_of_lt hsPos) + exact Nat.lt_of_le_of_lt hle hzz0 + have hz1W : z1 < WORD_MOD := Nat.lt_of_lt_of_le hz1_87 (Nat.le_of_lt (two_pow_lt_word 87 (by decide))) + have hz2W : z2 < WORD_MOD := Nat.lt_of_lt_of_le hz2_87 (Nat.le_of_lt (two_pow_lt_word 87 (by decide))) + have hz3W : z3 < WORD_MOD := Nat.lt_of_lt_of_le hz3_87 (Nat.le_of_lt (two_pow_lt_word 87 (by decide))) + have hz4W : z4 < WORD_MOD := Nat.lt_of_lt_of_le hz4_87 (Nat.le_of_lt (two_pow_lt_word 87 (by decide))) + have hz5W : z5 < WORD_MOD := Nat.lt_of_lt_of_le hz5_87 (Nat.le_of_lt (two_pow_lt_word 87 (by decide))) + -- EVM step = norm step for each iteration + have hstep1 : evmDiv (evmAdd (evmAdd (evmDiv x (evmMul z0 z0)) z0) z0) 3 = z1 := by + have h := step_evm_eq_norm_of_safe x z0 hx256 hsPos hz0W hzz0 hsum0 + simpa [z1, normStep_eq_cbrtStep] using h + have hstep2 : evmDiv (evmAdd (evmAdd (evmDiv x (evmMul z1 z1)) z1) z1) 3 = z2 := by + have h := step_evm_eq_norm_of_safe x z1 hx256 hz1Pos hz1W hzz1 hsum1 + simpa [z2, normStep_eq_cbrtStep] using h + have hstep3 : evmDiv (evmAdd (evmAdd (evmDiv x (evmMul z2 z2)) z2) z2) 3 = z3 := by + have h := step_evm_eq_norm_of_safe x z2 hx256 hz2Pos hz2W hzz2 hsum2 + simpa [z3, normStep_eq_cbrtStep] using h + have hstep4 : evmDiv (evmAdd (evmAdd (evmDiv x (evmMul z3 z3)) z3) z3) 3 = z4 := by + have h := step_evm_eq_norm_of_safe x z3 hx256 hz3Pos hz3W hzz3 hsum3 + simpa [z4, normStep_eq_cbrtStep] using h + have hstep5 : evmDiv (evmAdd (evmAdd (evmDiv x (evmMul z4 z4)) z4) z4) 3 = z5 := by + have h := step_evm_eq_norm_of_safe x z4 hx256 hz4Pos hz4W hzz4 hsum4 + simpa [z5, normStep_eq_cbrtStep] using h + have hstep6 : evmDiv (evmAdd (evmAdd (evmDiv x (evmMul z5 z5)) z5) z5) 3 = z6 := by + have h := step_evm_eq_norm_of_safe x z5 hx256 hz5Pos hz5W hzz5 hsum5 + simpa [z6, normStep_eq_cbrtStep] using h + -- Seed: EVM = norm + have hseedNorm : + normAdd 1 (normShr 8 (normShl (normDiv (normBitLengthPlus1 x) 3) 233)) = + seedOf idx := by + exact (normSeed_eq_cbrtSeed_of_pos x hx).trans hseedOf + have hseedEvm : + evmAdd 1 (evmShr 8 (evmShl (evmDiv (evmSub 257 (evmClz x)) 3) 233)) = + seedOf idx := by + have hOldNorm := (normSub257Clz_eq_cbrtSeed_of_pos x hx hx256).trans hseedOf + exact (seed_evm_eq_norm x hx256).trans hOldNorm + -- Final assembly + have hxmod : u256 x = x := u256_eq_of_lt x hx256 + unfold model_cbrt_evm model_cbrt + simp [hxmod, hseedEvm, hseedNorm, z0, z1, z2, z3, z4, z5, z6, + hstep1, hstep2, hstep3, hstep4, hstep5, hstep6, normStep_eq_cbrtStep] + +-- Bracket result for the EVM model +theorem model_cbrt_evm_bracket_u256_all + (x : Nat) (hx : 0 < x) (hx256 : x < 2 ^ 256) : + let m := icbrt x + m ≤ model_cbrt_evm x ∧ model_cbrt_evm x ≤ m + 1 := by + have hxW : x < WORD_MOD := by simpa [WORD_MOD] using hx256 + simpa [model_cbrt_evm_eq_model_cbrt x hxW] using model_cbrt_bracket_u256_all x hx hx256 + +-- ============================================================================ +-- Level 3: Floor correction +-- ============================================================================ + +private theorem floor_correction_norm_eq_if (x z : Nat) : + normSub z (normLt (normDiv x (normMul z z)) z) = + (if x / (z * z) < z then z - 1 else z) := by + by_cases hz0 : z = 0 + · subst hz0; simp [normSub, normLt, normDiv, normMul] + · by_cases hlt : x / (z * z) < z + · simp [normSub, normLt, normDiv, normMul, hlt] + · simp [normSub, normLt, normDiv, normMul, hlt] + +theorem model_cbrt_floor_eq_floorCbrt + (x : Nat) : + model_cbrt_floor x = floorCbrt x := by + have hinner : model_cbrt x = innerCbrt x := model_cbrt_eq_innerCbrt x + unfold model_cbrt_floor floorCbrt + simp [hinner, floor_correction_norm_eq_if] + +private theorem normLt_div_zsq_le (x z : Nat) : + normLt (normDiv x (normMul z z)) z ≤ z := by + by_cases hz0 : z = 0 + · simp [normLt, normDiv, normMul, hz0] + · by_cases hlt : x / (z * z) < z + · simp [normLt, normDiv, normMul, hlt] + exact Nat.succ_le_of_lt (Nat.pos_of_ne_zero hz0) + · simp [normLt, normDiv, normMul, hlt] + +private theorem floor_step_evm_eq_norm + (x z : Nat) (hx : x < WORD_MOD) (hz : z < WORD_MOD) (hzzW : z * z < WORD_MOD) : + evmSub z (evmLt (evmDiv x (evmMul z z)) z) = + normSub z (normLt (normDiv x (normMul z z)) z) := by + have hmul : evmMul z z = normMul z z := + evmMul_eq_normMul_of_no_overflow z z hz hz hzzW + have hmulLt : normMul z z < WORD_MOD := by simpa [normMul] using hzzW + have hdiv : evmDiv x (evmMul z z) = normDiv x (normMul z z) := by + rw [hmul]; exact evmDiv_eq_normDiv_of_u256 x (normMul z z) hx hmulLt + have hdivLt : normDiv x (normMul z z) < WORD_MOD := + Nat.lt_of_le_of_lt (by simp [normDiv, normMul]; exact Nat.div_le_self _ _) hx + have hlt : evmLt (evmDiv x (evmMul z z)) z = normLt (normDiv x (normMul z z)) z := by + simpa [hdiv] using evmLt_eq_normLt_of_u256 (normDiv x (normMul z z)) z hdivLt hz + have hbLe : normLt (normDiv x (normMul z z)) z ≤ z := normLt_div_zsq_le x z + calc evmSub z (evmLt (evmDiv x (evmMul z z)) z) + = evmSub z (normLt (normDiv x (normMul z z)) z) := by rw [hlt] + _ = normSub z (normLt (normDiv x (normMul z z)) z) := + evmSub_eq_normSub_of_le z (normLt (normDiv x (normMul z z)) z) hz hbLe + +theorem model_cbrt_floor_evm_eq_model_cbrt_floor + (x : Nat) (hxW : x < WORD_MOD) : + model_cbrt_floor_evm x = model_cbrt_floor x := by + have hx256 : x < 2 ^ 256 := by simpa [WORD_MOD] using hxW + have hxmod : u256 x = x := u256_eq_of_lt x hxW + by_cases hx0 : x = 0 + · subst hx0; decide + · have hx : 0 < x := Nat.pos_of_ne_zero hx0 + have hbr := model_cbrt_evm_bracket_u256_all x hx hx256 + have hm86 : icbrt x < 2 ^ 86 := m_lt_pow86_of_u256 (icbrt x) x (icbrt_cube_le x) hxW + have hz87 : model_cbrt_evm x < 2 ^ 87 := by + have : model_cbrt_evm x ≤ icbrt x + 1 := hbr.2; omega + have hzW : model_cbrt_evm x < WORD_MOD := + Nat.lt_of_lt_of_le hz87 (Nat.le_of_lt (two_pow_lt_word 87 (by decide))) + have hzzW : model_cbrt_evm x * model_cbrt_evm x < WORD_MOD := + zsq_lt_word_of_lt_87 (model_cbrt_evm x) hz87 + have hroot : model_cbrt_evm x = model_cbrt x := model_cbrt_evm_eq_model_cbrt x hxW + unfold model_cbrt_floor_evm model_cbrt_floor + simp [hxmod] + simpa [hroot] using floor_step_evm_eq_norm x (model_cbrt_evm x) hxW hzW hzzW + +theorem model_cbrt_floor_evm_eq_floorCbrt + (x : Nat) (hx256 : x < 2 ^ 256) : + model_cbrt_floor_evm x = floorCbrt x := by + have hxW : x < WORD_MOD := by simpa [WORD_MOD] using hx256 + calc model_cbrt_floor_evm x + = model_cbrt_floor x := model_cbrt_floor_evm_eq_model_cbrt_floor x hxW + _ = floorCbrt x := model_cbrt_floor_eq_floorCbrt x + +-- Combined with Wiring's floorCbrt_correct_u256: +theorem model_cbrt_floor_evm_correct + (x : Nat) (hx : 0 < x) (hx256 : x < 2 ^ 256) : + model_cbrt_floor_evm x = icbrt x := by + calc model_cbrt_floor_evm x + = floorCbrt x := model_cbrt_floor_evm_eq_floorCbrt x hx256 + _ = icbrt x := floorCbrt_correct_u256 x hx hx256 + +-- ============================================================================ +-- Level 4: cbrtUp +-- ============================================================================ + +/-- Specification-level model for `cbrtUp`: round `innerCbrt` upward if needed. -/ +def cbrtUpSpec (x : Nat) : Nat := + let z := innerCbrt x + if z * z * z < x then z + 1 else z + +-- The Nat-level cbrtUp spec equivalence. +-- Trivial with the new model: z * (z * z) = z * z * z by associativity, +-- so normLt(normMul z (normMul z z), x) = if z*z*z < x then 1 else 0. +private theorem model_cbrt_up_norm_eq_cbrtUpSpec + (x : Nat) : + model_cbrt_up x = cbrtUpSpec x := by + have hinner : model_cbrt x = innerCbrt x := model_cbrt_eq_innerCbrt x + unfold model_cbrt_up cbrtUpSpec + simp only [hinner, normAdd, normLt, normMul, Nat.mul_assoc] + -- Both sides now have z*(z*z). Just need: z + (if _ then 1 else 0) = if _ then z+1 else z + split <;> simp_all + +theorem model_cbrt_up_eq_cbrtUpSpec + (x : Nat) : + model_cbrt_up x = cbrtUpSpec x := + model_cbrt_up_norm_eq_cbrtUpSpec x + +-- EVM cbrtUp = cbrtUpSpec. +-- Key overflow facts for the new model (z + lt(mul(z, mul(z, z)), x)): +-- z = model_cbrt_evm x ∈ [m, m+1], m < 2^86, so z < 2^87 +-- z² < 2^174 < 2^256 (no overflow in inner mul) +-- z³ < 2^256 (proven in OverflowSafety via innerCbrt_cube_lt_word) +-- lt(...) ≤ 1, z + 1 < 2^87 + 1 < 2^256 (no overflow in final add) +theorem model_cbrt_up_evm_eq_cbrtUpSpec + (x : Nat) (hx : 0 < x) (hx256 : x < 2 ^ 256) : + model_cbrt_up_evm x = cbrtUpSpec x := by + -- Strategy: show model_cbrt_up_evm x = model_cbrt_up x, then use the Nat proof. + have hxW : x < WORD_MOD := by simpa [WORD_MOD] using hx256 + have hroot : model_cbrt_evm x = model_cbrt x := model_cbrt_evm_eq_model_cbrt x hxW + have hxmod : u256 x = x := u256_eq_of_lt x hxW + have hinner : model_cbrt x = innerCbrt x := model_cbrt_eq_innerCbrt x + have hbr := model_cbrt_evm_bracket_u256_all x hx hx256 + have hm86 : icbrt x < 2 ^ 86 := m_lt_pow86_of_u256 (icbrt x) x (icbrt_cube_le x) hxW + have hz87 : innerCbrt x < 2 ^ 87 := by + have := hbr.2; rw [hroot, hinner] at this; omega + have hzW : innerCbrt x < WORD_MOD := + Nat.lt_of_lt_of_le hz87 (Nat.le_of_lt (two_pow_lt_word 87 (by decide))) + have hzzW : innerCbrt x * innerCbrt x < WORD_MOD := zsq_lt_word_of_lt_87 _ hz87 + -- z * (z * z) < WORD_MOD (the key new overflow fact) + have hcubeW : innerCbrt x * (innerCbrt x * innerCbrt x) < WORD_MOD := by + have := CbrtOverflow.innerCbrt_cube_lt_word x hx hx256 + simpa [WORD_MOD] using this + have hup_nat : model_cbrt_up x = cbrtUpSpec x := + model_cbrt_up_norm_eq_cbrtUpSpec x + -- Show model_cbrt_up_evm x = model_cbrt_up x. + suffices h : model_cbrt_up_evm x = model_cbrt_up x by rw [h]; exact hup_nat + unfold model_cbrt_up_evm model_cbrt_up + simp only [hxmod, hroot, hinner] + -- Goal: evmAdd z (evmLt (evmMul z (evmMul z z)) x) + -- = normAdd z (normLt (normMul z (normMul z z)) x) + -- where z = innerCbrt x. + -- 1. evmMul z z = normMul z z (z² < WORD_MOD) + have hmul_zz : evmMul (innerCbrt x) (innerCbrt x) = normMul (innerCbrt x) (innerCbrt x) := + evmMul_eq_normMul_of_no_overflow _ _ hzW hzW hzzW + rw [hmul_zz] + -- 2. evmMul z (normMul z z) = normMul z (normMul z z) (z³ < WORD_MOD) + have hmulLt : normMul (innerCbrt x) (innerCbrt x) < WORD_MOD := by + simpa [normMul] using hzzW + have hcube_mul : evmMul (innerCbrt x) (normMul (innerCbrt x) (innerCbrt x)) = + normMul (innerCbrt x) (normMul (innerCbrt x) (innerCbrt x)) := by + have hprod : innerCbrt x * normMul (innerCbrt x) (innerCbrt x) < WORD_MOD := by + simp [normMul]; exact hcubeW + exact evmMul_eq_normMul_of_no_overflow _ _ hzW hmulLt hprod + rw [hcube_mul] + -- 3. evmLt (normMul z (normMul z z)) x = normLt (normMul z (normMul z z)) x + have hcubeLt : normMul (innerCbrt x) (normMul (innerCbrt x) (innerCbrt x)) < WORD_MOD := by + simp [normMul]; exact hcubeW + have hlt_eq : evmLt (normMul (innerCbrt x) (normMul (innerCbrt x) (innerCbrt x))) x = + normLt (normMul (innerCbrt x) (normMul (innerCbrt x) (innerCbrt x))) x := + evmLt_eq_normLt_of_u256 _ x hcubeLt hxW + rw [hlt_eq] + -- 4. evmAdd z (normLt ...) = normAdd z (normLt ...) + have hltVal : normLt (normMul (innerCbrt x) (normMul (innerCbrt x) (innerCbrt x))) x ≤ 1 := by + unfold normLt; split <;> omega + have hltLt : normLt (normMul (innerCbrt x) (normMul (innerCbrt x) (innerCbrt x))) x < WORD_MOD := + Nat.lt_of_le_of_lt hltVal one_lt_word + have hfinalLt : innerCbrt x + normLt (normMul (innerCbrt x) (normMul (innerCbrt x) (innerCbrt x))) x < WORD_MOD := by + have h87W : 2 ^ 87 + 1 < WORD_MOD := by unfold WORD_MOD; decide + calc innerCbrt x + normLt (normMul (innerCbrt x) (normMul (innerCbrt x) (innerCbrt x))) x + ≤ innerCbrt x + 1 := Nat.add_le_add_left hltVal _ + _ ≤ 2 ^ 87 + 1 := Nat.add_le_add_right (Nat.le_of_lt hz87) _ + _ < WORD_MOD := h87W + exact evmAdd_eq_normAdd_of_no_overflow _ _ hzW hltLt hfinalLt + +-- ============================================================================ +-- Level 4b: cbrtUp upper-bound correctness +-- ============================================================================ + +/-- cbrtUpSpec gives a valid upper bound: x ≤ r³. -/ +theorem cbrtUpSpec_upper_bound + (x : Nat) (hx : 0 < x) (hx256 : x < 2 ^ 256) : + x ≤ cbrtUpSpec x * cbrtUpSpec x * cbrtUpSpec x := by + let m := icbrt x + have hmlo : m * m * m ≤ x := icbrt_cube_le x + have hmhi : x < (m + 1) * (m + 1) * (m + 1) := icbrt_lt_succ_cube x + have hbr : m ≤ innerCbrt x ∧ innerCbrt x ≤ m + 1 := by + constructor + · exact innerCbrt_lower x m hx hmlo + · exact innerCbrt_upper_u256 x hx hx256 + unfold cbrtUpSpec + by_cases hlt : innerCbrt x * innerCbrt x * innerCbrt x < x + · simp [hlt] + -- innerCbrt x = m (otherwise (m+1)³ < x, contradicting hmhi) + have hzm : innerCbrt x = m := by + have hneq : innerCbrt x ≠ m + 1 := by + intro hce; rw [hce] at hlt; omega + omega + rw [hzm]; exact Nat.le_of_lt hmhi + · simp [hlt]; exact Nat.le_of_not_gt hlt + +theorem model_cbrt_up_evm_upper_bound + (x : Nat) (hx : 0 < x) (hx256 : x < 2 ^ 256) : + x ≤ model_cbrt_up_evm x * model_cbrt_up_evm x * model_cbrt_up_evm x := by + rw [model_cbrt_up_evm_eq_cbrtUpSpec x hx hx256] + exact cbrtUpSpec_upper_bound x hx hx256 + +-- ============================================================================ +-- Level 4c: cbrtUp lower bound (exact ceiling) +-- ============================================================================ + +/-- cbrtUpSpec gives a tight lower bound: (r-1)³ < x. + Combined with the upper bound (x ≤ r³), this shows r = ⌈∛x⌉. -/ +theorem cbrtUpSpec_lower_bound + (x : Nat) (hx : 0 < x) (hx256 : x < 2 ^ 256) : + (cbrtUpSpec x - 1) * (cbrtUpSpec x - 1) * (cbrtUpSpec x - 1) < x := by + have hmlo : icbrt x * icbrt x * icbrt x ≤ x := icbrt_cube_le x + have hupper : innerCbrt x ≤ icbrt x + 1 := innerCbrt_upper_u256 x hx hx256 + have hlower : icbrt x ≤ innerCbrt x := innerCbrt_lower x (icbrt x) hx hmlo + unfold cbrtUpSpec + by_cases hlt : innerCbrt x * innerCbrt x * innerCbrt x < x + · -- innerCbrt(x)³ < x: cbrtUpSpec = innerCbrt(x) + 1, (innerCbrt(x)+1-1)³ = innerCbrt(x)³ < x + simp [hlt] + · -- innerCbrt(x)³ ≥ x: cbrtUpSpec = innerCbrt(x) + simp [hlt] + -- Need: (innerCbrt(x) - 1)³ < x. Case split: innerCbrt(x) = icbrt(x) or icbrt(x)+1. + have hcases := innerCbrt_correct_of_upper x hx hupper + rcases hcases with heqm | heqm1 + · -- innerCbrt(x) = icbrt(x) = m. Need (m-1)³ < x. + rw [heqm] + -- m > 0 since x > 0 implies icbrt(x) > 0 + have hm_pos : 0 < icbrt x := by + by_cases h0 : icbrt x = 0 + · -- icbrt(x) = 0 means 0³ ≤ x < 1³ = 1, so x = 0, contradicting hx > 0. + have := icbrt_lt_succ_cube x; rw [h0] at this; simp at this; omega + · exact Nat.pos_of_ne_zero h0 + -- (m-1)³ < m³ ≤ x + have : (icbrt x - 1) * (icbrt x - 1) * (icbrt x - 1) < + icbrt x * icbrt x * icbrt x := by + have hpred : icbrt x - 1 < icbrt x := Nat.sub_lt hm_pos (by omega) + -- (m-1)³ ≤ (m-1)² * m < m² * m = m³ + calc (icbrt x - 1) * (icbrt x - 1) * (icbrt x - 1) + ≤ (icbrt x - 1) * (icbrt x - 1) * icbrt x := + Nat.mul_le_mul_left _ (Nat.le_of_lt hpred) + _ ≤ (icbrt x - 1) * icbrt x * icbrt x := + Nat.mul_le_mul_right _ (Nat.mul_le_mul_left _ (Nat.le_of_lt hpred)) + _ < icbrt x * icbrt x * icbrt x := + Nat.mul_lt_mul_of_pos_right + (Nat.mul_lt_mul_of_pos_right hpred hm_pos) + hm_pos + omega + · -- innerCbrt(x) = icbrt(x) + 1. Need (icbrt(x))³ < x. + rw [heqm1]; simp + -- Since innerCbrt(x) = icbrt(x)+1 and innerCbrt(m³) = m for m = icbrt(x), + -- x ≠ icbrt(x)³. Combined with icbrt(x)³ ≤ x: strict inequality. + have hm_pos : 0 < icbrt x := by + by_cases h0 : icbrt x = 0 + · have := icbrt_lt_succ_cube x; rw [h0] at this; simp at this; omega + · exact Nat.pos_of_ne_zero h0 + have hx_ne : x ≠ icbrt x * icbrt x * icbrt x := by + intro hxeq + have hpc := CbrtWiring.innerCbrt_on_perfect_cube (icbrt x) + hm_pos (by rw [← hxeq]; exact hx256) + -- hpc : innerCbrt (icbrt x * icbrt x * icbrt x) = icbrt x + -- heqm1 : innerCbrt x = icbrt x + 1 + -- From hxeq: x = icbrt x * icbrt x * icbrt x + have : innerCbrt (icbrt x * icbrt x * icbrt x) = icbrt x + 1 := by + rwa [← hxeq] + -- Contradiction: icbrt x = icbrt x + 1 + omega + omega + +/-- The EVM cbrtUp model gives a tight lower bound: (r-1)³ < x. -/ +theorem model_cbrt_up_evm_lower_bound + (x : Nat) (hx : 0 < x) (hx256 : x < 2 ^ 256) : + (model_cbrt_up_evm x - 1) * (model_cbrt_up_evm x - 1) * (model_cbrt_up_evm x - 1) < x := by + rw [model_cbrt_up_evm_eq_cbrtUpSpec x hx hx256] + exact cbrtUpSpec_lower_bound x hx hx256 + +/-- Combined: the EVM cbrtUp model gives the exact ceiling cube root. -/ +theorem model_cbrt_up_evm_is_ceil + (x : Nat) (hx : 0 < x) (hx256 : x < 2 ^ 256) : + let r := model_cbrt_up_evm x + (r - 1) * (r - 1) * (r - 1) < x ∧ x ≤ r * r * r := by + exact ⟨model_cbrt_up_evm_lower_bound x hx hx256, + model_cbrt_up_evm_upper_bound x hx hx256⟩ + +/-- cbrtUp is correct for ALL x < 2^256 (including x = 0). -/ +theorem model_cbrt_up_evm_is_ceil_all + (x : Nat) (hx256 : x < 2 ^ 256) : + let r := model_cbrt_up_evm x + x ≤ r * r * r ∧ (r = 0 ∨ (r - 1) * (r - 1) * (r - 1) < x) := by + by_cases hx : 0 < x + · have ⟨hlo, hhi⟩ := model_cbrt_up_evm_is_ceil x hx hx256 + exact ⟨hhi, Or.inr hlo⟩ + · simp at hx + subst hx + decide + +-- ============================================================================ +-- Level 4d: cbrtUp minimality (smallest integer with r³ ≥ x) +-- ============================================================================ + +/-- If `r = 0` or `(r-1)³ < x`, then `r` is the smallest value whose cube is ≥ x. -/ +private theorem minimal_of_pred_cube_lt + (x r : Nat) + (hpred : r = 0 ∨ (r - 1) * (r - 1) * (r - 1) < x) : + ∀ y, x ≤ y * y * y → r ≤ y := by + intro y hy + by_cases hry : r ≤ y + · exact hry + · have hylt : y < r := Nat.lt_of_not_ge hry + cases hpred with + | inl hr0 => + exact False.elim ((Nat.not_lt_of_ge hylt) (by simp [hr0])) + | inr hpredlt => + have hyle : y ≤ r - 1 := by omega + have hycube : y * y * y ≤ (r - 1) * (r - 1) * (r - 1) := cube_monotone hyle + have hcontra : x ≤ (r - 1) * (r - 1) * (r - 1) := Nat.le_trans hy hycube + exact False.elim ((Nat.not_lt_of_ge hcontra) hpredlt) + +/-- `cbrtUp` is exactly the smallest integer whose cube is ≥ x. + Matches the sqrt analog `model_sqrt_up_evm_ceil_u256`. -/ +theorem model_cbrt_up_evm_ceil_u256 + (x : Nat) + (hx256 : x < 2 ^ 256) : + let r := model_cbrt_up_evm x + x ≤ r * r * r ∧ ∀ y, x ≤ y * y * y → r ≤ y := by + have hceil := model_cbrt_up_evm_is_ceil_all x hx256 + exact ⟨hceil.1, minimal_of_pred_cube_lt x (model_cbrt_up_evm x) hceil.2⟩ + +-- ============================================================================ +-- Summary +-- ============================================================================ + +/- + PROOF STATUS: + + ✓ normStep_eq_cbrtStep: NR step norm = cbrtStep + ✓ normSeed_eq_cbrtSeed_of_pos: norm seed = cbrtSeed (no uint256 bound) + ✓ normSub257Clz_eq_cbrtSeed_of_pos: old sub/clz seed = cbrtSeed (bridge for EVM) + ✓ model_cbrt_eq_innerCbrt: Nat model = hand-written innerCbrt (no uint256 bound) + ✓ model_cbrt_bracket_u256_all: Nat model ∈ [m, m+1] + ✓ model_cbrt_floor_eq_floorCbrt: Nat floor model = floorCbrt (no uint256 bound) + ✓ model_cbrt_up_eq_cbrtUpSpec: Nat cbrtUp model = cbrtUpSpec (no uint256 bound) + ✓ model_cbrt_up_evm_eq_cbrtUpSpec: EVM cbrtUp model = cbrtUpSpec + ✓ cbrtUpSpec_upper_bound: cbrtUpSpec gives valid upper bound + ✓ cbrtUpSpec_lower_bound: cbrtUpSpec gives tight lower bound (exact ceiling) + ✓ model_cbrt_up_evm_upper_bound: EVM cbrtUp gives valid upper bound + ✓ model_cbrt_up_evm_lower_bound: EVM cbrtUp gives tight lower bound + ✓ model_cbrt_up_evm_is_ceil: EVM cbrtUp is the exact ceiling cube root (x > 0) + ✓ model_cbrt_up_evm_is_ceil_all: EVM cbrtUp is correct for all x < 2^256 (including x = 0) + ✓ model_cbrt_up_evm_ceil_u256: cbrtUp is the smallest integer with r³ ≥ x + ✓ model_cbrt_evm_eq_model_cbrt: EVM model = Nat model + ✓ model_cbrt_evm_bracket_u256_all: EVM model ∈ [m, m+1] + ✓ model_cbrt_floor_evm_eq_floorCbrt: EVM floor = floorCbrt + ✓ model_cbrt_floor_evm_correct: EVM floor = icbrt +-/ + +end CbrtGeneratedModel diff --git a/formal/cbrt/CbrtProof/CbrtProof/OverflowSafety.lean b/formal/cbrt/CbrtProof/CbrtProof/OverflowSafety.lean new file mode 100644 index 000000000..411272e26 --- /dev/null +++ b/formal/cbrt/CbrtProof/CbrtProof/OverflowSafety.lean @@ -0,0 +1,313 @@ +/- + Overflow safety proof for cbrtUp. + + Main theorem: `innerCbrt_cube_lt_word` + For all x < 2^256, innerCbrt(x) * (innerCbrt(x) * innerCbrt(x)) < 2^256. +-/ +import Init +import CbrtProof.CbrtCorrect +import CbrtProof.FiniteCert +import CbrtProof.CertifiedChain +import CbrtProof.Wiring + +set_option exponentiation.threshold 300 + +namespace CbrtOverflow + +open CbrtCert +open CbrtCertified +open CbrtWiring + +-- ============================================================================ +-- Constants +-- ============================================================================ + +private def R_MAX : Nat := 48740834812604276470692694 + +private def hiD1 : Nat := 5865868362315021153806969 +private def hiD2 : Nat := hiD1 * hiD1 / R_MAX + 1 +private def hiD3 : Nat := hiD2 * hiD2 / R_MAX + 1 +private def hiD4 : Nat := hiD3 * hiD3 / R_MAX + 1 +private def hiD5 : Nat := hiD4 * hiD4 / R_MAX + 1 + +-- ============================================================================ +-- Verified constants (kernel-checked via decide, no native_decide) +-- ============================================================================ + +private theorem r_max_cube_lt_word : R_MAX * R_MAX * R_MAX < 2 ^ 256 := by decide +private theorem r_max_succ_cube_ge_word : + 2 ^ 256 ≤ (R_MAX + 1) * (R_MAX + 1) * (R_MAX + 1) := by decide +set_option maxRecDepth 1000000 in +private theorem hiD1_eq : hiD1 = d1Of ⟨247, by omega⟩ := by decide +private theorem hiD5_sq_lt_rmax : hiD5 * hiD5 < R_MAX := by decide +private theorem two_hiD1_le_rmax : 2 * hiD1 ≤ R_MAX := by decide +private theorem two_hiD2_le_rmax : 2 * hiD2 ≤ R_MAX := by decide +private theorem two_hiD3_le_rmax : 2 * hiD3 ≤ R_MAX := by decide +private theorem two_hiD4_le_rmax : 2 * hiD4 ≤ R_MAX := by decide +private theorem two_hiD5_le_rmax : 2 * hiD5 ≤ R_MAX := by decide +private theorem pow255_le_rmax_cube : 2 ^ 255 ≤ R_MAX * R_MAX * R_MAX := by decide +private theorem fBound_at_zero : + (R_MAX + 3) * (R_MAX * R_MAX) ≥ 2 ^ 256 := by decide +private theorem fBound_at_hiD5 : + (R_MAX + 3 - 2 * hiD5) * ((R_MAX + hiD5) * (R_MAX + hiD5)) ≥ 2 ^ 256 := by decide + +-- d1 bound for octave 247 matches the analytic formula (decide) +set_option maxRecDepth 1000000 in +private theorem d1_bound_247 : + (max (seedOf ⟨247, by omega⟩ - loOf ⟨247, by omega⟩) (hiOf ⟨247, by omega⟩ - seedOf ⟨247, by omega⟩) * + max (seedOf ⟨247, by omega⟩ - loOf ⟨247, by omega⟩) (hiOf ⟨247, by omega⟩ - seedOf ⟨247, by omega⟩) * + (hiOf ⟨247, by omega⟩ + 2 * seedOf ⟨247, by omega⟩) + + 3 * hiOf ⟨247, by omega⟩ * (hiOf ⟨247, by omega⟩ + 1)) / + (3 * (seedOf ⟨247, by omega⟩ * seedOf ⟨247, by omega⟩)) = d1Of ⟨247, by omega⟩ := by decide + +-- ============================================================================ +-- Nat polynomial identity: (b-2)(2b+1) + 3b + 2 = 2b² for b ≥ 2 +-- ============================================================================ + +private theorem poly_ident (b : Nat) (hb : 2 ≤ b) : + (b - 2) * (2 * b + 1) + (3 * b + 2) = 2 * b * b := by + generalize hc : b - 2 = c + have hb_eq : b = c + 2 := by omega + subst hb_eq + simp only [Nat.mul_add, Nat.add_mul, Nat.mul_comm, Nat.mul_left_comm, Nat.add_assoc] + omega + +-- ============================================================================ +-- Discrete monotonicity of f(e) = (R+3-2e)(R+e)² +-- ============================================================================ + +/-- f(e) ≥ f(e+1) for e ≥ 1 with 2e ≤ R+1. + Proof: (a+2)b² ≥ a(b+1)² where a = R+1-2e, b = R+e. + Expand: suffices 2b² ≥ a(2b+1). Since a ≤ b-2 and (b-2)(2b+1) ≤ 2b². -/ +private theorem fBound_step_le (e : Nat) (he : 1 ≤ e) (h2e : 2 * e ≤ R_MAX + 1) : + (R_MAX + 3 - 2 * e) * ((R_MAX + e) * (R_MAX + e)) ≥ + (R_MAX + 3 - 2 * (e + 1)) * ((R_MAX + (e + 1)) * (R_MAX + (e + 1))) := by + -- Let a = R+1-2e, b = R+e (avoid `set` which requires Mathlib) + -- Rewrite goal in terms of a, b + have h1 : R_MAX + 3 - 2 * e = (R_MAX + 1 - 2 * e) + 2 := by omega + have h2 : R_MAX + 3 - 2 * (e + 1) = R_MAX + 1 - 2 * e := by omega + have h3 : R_MAX + (e + 1) = (R_MAX + e) + 1 := by omega + rw [h1, h2, h3] + -- Goal: ((R+1-2e)+2) * ((R+e)*(R+e)) ≥ (R+1-2e) * (((R+e)+1)*((R+e)+1)) + -- Abbreviate: a = R+1-2e, b = R+e + -- Goal: (a+2)*(b*b) ≥ a*((b+1)*(b+1)) + -- Suffices: 2*b*b ≥ a*(2*b+1). + suffices hsuff : (R_MAX + 1 - 2 * e) * (2 * (R_MAX + e) + 1) ≤ 2 * (R_MAX + e) * (R_MAX + e) by + -- (a+2)*b² = a*b² + 2*b² + have hexp : ((R_MAX + 1 - 2 * e) + 2) * ((R_MAX + e) * (R_MAX + e)) = + (R_MAX + 1 - 2 * e) * ((R_MAX + e) * (R_MAX + e)) + 2 * ((R_MAX + e) * (R_MAX + e)) := by + simp only [Nat.add_mul] + -- a*((b+1)*(b+1)) = a*(b*b) + a*(2*b+1) + have hexp2 : (R_MAX + 1 - 2 * e) * (((R_MAX + e) + 1) * ((R_MAX + e) + 1)) = + (R_MAX + 1 - 2 * e) * ((R_MAX + e) * (R_MAX + e)) + + (R_MAX + 1 - 2 * e) * (2 * (R_MAX + e) + 1) := by + have : ((R_MAX + e) + 1) * ((R_MAX + e) + 1) = + (R_MAX + e) * (R_MAX + e) + (2 * (R_MAX + e) + 1) := by + simp only [Nat.add_mul, Nat.mul_add, Nat.mul_comm, Nat.add_assoc]; omega + rw [this, Nat.mul_add] + -- Goal: (a+2)*b² ≥ a*(b+1)², i.e., a*(b+1)² ≤ (a+2)*b² + -- After rewriting both sides: a*b² + a*(2b+1) ≤ a*b² + 2*b² + show (R_MAX + 1 - 2 * e) * (((R_MAX + e) + 1) * ((R_MAX + e) + 1)) ≤ + ((R_MAX + 1 - 2 * e) + 2) * ((R_MAX + e) * (R_MAX + e)) + rw [hexp, hexp2] + have hassoc : 2 * (R_MAX + e) * (R_MAX + e) = 2 * ((R_MAX + e) * (R_MAX + e)) := by + rw [Nat.mul_assoc] + rw [hassoc] at hsuff + exact Nat.add_le_add_left hsuff _ + -- Need: a*(2b+1) ≤ 2b². + -- a ≤ b - 2 (since b - a = R+e - R - 1 + 2e = 3e - 1 ≥ 2) + have hab : R_MAX + 1 - 2 * e ≤ (R_MAX + e) - 2 := by omega + have hb2 : 2 ≤ R_MAX + e := by omega + -- (b-2)*(2b+1) + (3b+2) = 2b² (from poly_ident) + have hpoly := poly_ident (R_MAX + e) hb2 + -- So (b-2)*(2b+1) ≤ 2b². + have hbd : ((R_MAX + e) - 2) * (2 * (R_MAX + e) + 1) ≤ 2 * (R_MAX + e) * (R_MAX + e) := by omega + -- a*(2b+1) ≤ (b-2)*(2b+1) ≤ 2b². + exact Nat.le_trans (Nat.mul_le_mul_right _ hab) hbd + +/-- f is non-increasing on [1, hiD5]: for e in this range, f(e) ≥ f(hiD5). -/ +private theorem fBound_ge_endpoint (e : Nat) (he1 : 1 ≤ e) (he2 : e ≤ hiD5) : + (R_MAX + 3 - 2 * e) * ((R_MAX + e) * (R_MAX + e)) ≥ + (R_MAX + 3 - 2 * hiD5) * ((R_MAX + hiD5) * (R_MAX + hiD5)) := by + -- Induction on n = hiD5 - e, generalizing e. + have key : ∀ n, ∀ e', 1 ≤ e' → e' ≤ hiD5 → n = hiD5 - e' → + (R_MAX + 3 - 2 * e') * ((R_MAX + e') * (R_MAX + e')) ≥ + (R_MAX + 3 - 2 * hiD5) * ((R_MAX + hiD5) * (R_MAX + hiD5)) := by + intro n + induction n with + | zero => + intro e' he1' _ hk + have : e' = hiD5 := by omega + subst this + exact Nat.le_refl _ + | succ k ih => + intro e' he1' he2' hk + -- Apply the inductive hypothesis to e' + 1 + have h_ih := ih (e' + 1) (by omega) (by omega) (by omega) + -- f(e') ≥ f(e'+1) by the step lemma + have h_step := fBound_step_le e' he1' (by have := two_hiD5_le_rmax; omega) + -- f(e') ≥ f(e'+1) ≥ f(hiD5) + exact Nat.le_trans h_ih h_step + exact key (hiD5 - e) e he1 he2 rfl + +/-- For all e ∈ [0, hiD5], f(e) ≥ 2^256. -/ +private theorem fBound_ge_word (e : Nat) (he : e ≤ hiD5) : + (R_MAX + 3 - 2 * e) * ((R_MAX + e) * (R_MAX + e)) ≥ 2 ^ 256 := by + by_cases he0 : e = 0 + · subst he0; simp; exact fBound_at_zero + · exact Nat.le_trans fBound_at_hiD5 (fBound_ge_endpoint e (by omega) he) + +-- ============================================================================ +-- cbrtStep bounded by R_MAX when z is close to R_MAX +-- ============================================================================ + +/-- If z ∈ [R_MAX, R_MAX + hiD5] and x < 2^256, then cbrtStep x z ≤ R_MAX. + Proof: x < f(d) = (R+3-2d)(R+d)² gives x/(R+d)² ≤ R+2-2d, + so x/(R+d)² + 2(R+d) ≤ 3R+2, and step ≤ R. -/ +private theorem cbrtStep_le_rmax + (x z : Nat) + (hx : x < 2 ^ 256) + (hmz : R_MAX ≤ z) + (hze : z ≤ R_MAX + hiD5) : + cbrtStep x z ≤ R_MAX := by + unfold cbrtStep + have hd_def : z - R_MAX ≤ hiD5 := by omega + have hzd : z = R_MAX + (z - R_MAX) := by omega + have h2d : 2 * (z - R_MAX) ≤ R_MAX := by have := two_hiD5_le_rmax; omega + -- f(d) ≥ 2^256 > x + have hf := fBound_ge_word (z - R_MAX) hd_def + have hf_gt_x : x < (R_MAX + 3 - 2 * (z - R_MAX)) * ((R_MAX + (z - R_MAX)) * (R_MAX + (z - R_MAX))) := + Nat.lt_of_lt_of_le hx (show 2 ^ 256 ≤ _ from hf) + rw [hzd] + -- x / (R+d)² < R+3-2d (by Nat.div_lt_iff_lt_mul) + have hzz_pos : 0 < (R_MAX + (z - R_MAX)) * (R_MAX + (z - R_MAX)) := + Nat.mul_pos (by unfold R_MAX; omega) (by unfold R_MAX; omega) + have hdiv_lt : x / ((R_MAX + (z - R_MAX)) * (R_MAX + (z - R_MAX))) < R_MAX + 3 - 2 * (z - R_MAX) := + (Nat.div_lt_iff_lt_mul hzz_pos).mpr hf_gt_x + -- So x/(R+d)² ≤ R+2-2d + have hdiv_bound : x / ((R_MAX + (z - R_MAX)) * (R_MAX + (z - R_MAX))) ≤ R_MAX + 2 - 2 * (z - R_MAX) := by omega + -- sum ≤ R+2-2d + 2(R+d) = 3R+2 + have hsum : x / ((R_MAX + (z - R_MAX)) * (R_MAX + (z - R_MAX))) + 2 * (R_MAX + (z - R_MAX)) ≤ 3 * R_MAX + 2 := by omega + -- step = sum/3 ≤ (3R+2)/3 = R + exact Nat.le_trans (Nat.div_le_div_right hsum) (by omega) + +-- ============================================================================ +-- Tighter 5-step chain +-- ============================================================================ + +/-- When m = R_MAX, z₅ ∈ [R_MAX, R_MAX + hiD5]. -/ +private theorem run5_hi_bound + (x : Nat) (hx : x < 2 ^ 256) (_hx_pos : 0 < x) + (hmlo : R_MAX * R_MAX * R_MAX ≤ x) + (hmhi : x < (R_MAX + 1) * (R_MAX + 1) * (R_MAX + 1)) : + R_MAX ≤ run5From x (seedOf ⟨247, by omega⟩) ∧ + run5From x (seedOf ⟨247, by omega⟩) ≤ R_MAX + hiD5 := by + let idx : Fin 248 := ⟨247, by omega⟩ + have hOct : 2 ^ (idx.val + certOffset) ≤ x ∧ x < 2 ^ (idx.val + certOffset + 1) := + ⟨Nat.le_trans pow255_le_rmax_cube hmlo, hx⟩ + have hinterval := m_within_cert_interval idx x R_MAX hmlo hmhi hOct + have hm2 : 2 ≤ R_MAX := by unfold R_MAX; omega + have hsPos : 0 < seedOf idx := seed_pos idx + -- Use run5_certified_bounds from CertifiedChain, but with R_MAX-specific d bounds. + -- We need our own chain because we use R_MAX as the denominator (not loOf idx). + -- The 5-step chain through run5From is definitionally: + -- run5From x s = cbrtStep x (cbrtStep x (cbrtStep x (cbrtStep x (cbrtStep x s)))) + -- We prove bounds step by step. + -- Step 1: floor bound + have hmz1 : R_MAX ≤ cbrtStep x (seedOf idx) := + cbrt_step_floor_bound x (seedOf idx) R_MAX hsPos hmlo + -- d1 bound from certificate + have hd1 : cbrtStep x (seedOf idx) - R_MAX ≤ hiD1 := by + rw [hiD1_eq] + have h := cbrt_d1_bound x R_MAX (seedOf idx) (loOf idx) (hiOf idx) hsPos hmlo hmhi + hinterval.1 hinterval.2 + simp only at h + exact Nat.le_trans h (Nat.le_of_eq d1_bound_247) + -- Steps 2-5 using step_from_bound with R_MAX as both m and lo + have hloPos : 0 < R_MAX := by omega + have hmz2 : R_MAX ≤ cbrtStep x (cbrtStep x (seedOf idx)) := + cbrt_step_floor_bound x _ R_MAX (by omega) hmlo + have hd2 : cbrtStep x (cbrtStep x (seedOf idx)) - R_MAX ≤ hiD2 := + step_from_bound x R_MAX R_MAX _ hiD1 hm2 hloPos (Nat.le_refl _) hmhi hmz1 hd1 two_hiD1_le_rmax + have hmz3 : R_MAX ≤ cbrtStep x (cbrtStep x (cbrtStep x (seedOf idx))) := + cbrt_step_floor_bound x _ R_MAX (by omega) hmlo + have hd3 : cbrtStep x (cbrtStep x (cbrtStep x (seedOf idx))) - R_MAX ≤ hiD3 := + step_from_bound x R_MAX R_MAX _ hiD2 hm2 hloPos (Nat.le_refl _) hmhi hmz2 hd2 two_hiD2_le_rmax + have hmz4 : R_MAX ≤ cbrtStep x (cbrtStep x (cbrtStep x (cbrtStep x (seedOf idx)))) := + cbrt_step_floor_bound x _ R_MAX (by omega) hmlo + have hd4 : cbrtStep x (cbrtStep x (cbrtStep x (cbrtStep x (seedOf idx)))) - R_MAX ≤ hiD4 := + step_from_bound x R_MAX R_MAX _ hiD3 hm2 hloPos (Nat.le_refl _) hmhi hmz3 hd3 two_hiD3_le_rmax + -- z5 = run5From x (seedOf idx) + have hmz5 : R_MAX ≤ cbrtStep x (cbrtStep x (cbrtStep x (cbrtStep x (cbrtStep x (seedOf idx))))) := + cbrt_step_floor_bound x _ R_MAX (by omega) hmlo + have hd5 : cbrtStep x (cbrtStep x (cbrtStep x (cbrtStep x (cbrtStep x (seedOf idx))))) - R_MAX ≤ hiD5 := + step_from_bound x R_MAX R_MAX _ hiD4 hm2 hloPos (Nat.le_refl _) hmhi hmz4 hd4 two_hiD4_le_rmax + -- run5From x (seedOf idx) is definitionally equal to the 5-step chain + have hrun5_def : run5From x (seedOf idx) = + cbrtStep x (cbrtStep x (cbrtStep x (cbrtStep x (cbrtStep x (seedOf idx))))) := rfl + rw [hrun5_def] + exact ⟨hmz5, by omega⟩ + +-- ============================================================================ +-- Main theorem +-- ============================================================================ + +/-- innerCbrt(x)³ < 2^256 for all x < 2^256 with x > 0. -/ +theorem innerCbrt_cube_lt_word (x : Nat) (hx : 0 < x) (hx256 : x < 2 ^ 256) : + innerCbrt x * (innerCbrt x * innerCbrt x) < 2 ^ 256 := by + rw [← Nat.mul_assoc] + let m := icbrt x + have hmlo : m * m * m ≤ x := icbrt_cube_le x + have hmhi : x < (m + 1) * (m + 1) * (m + 1) := icbrt_lt_succ_cube x + have hbr := innerCbrt_upper_u256 x hx hx256 + rcases innerCbrt_correct_of_upper x hx hbr with heqm | heqm1 + · -- Case z = m: m³ ≤ x < 2^256 + rw [heqm]; exact Nat.lt_of_le_of_lt hmlo hx256 + · -- Case z = m + 1 + rw [heqm1] + by_cases h_succ_lt : (m + 1) * (m + 1) * (m + 1) < 2 ^ 256 + · exact h_succ_lt + · -- (m+1)³ ≥ 2^256 implies m = R_MAX. Derive contradiction. + exfalso + have h_ge : 2 ^ 256 ≤ (m + 1) * (m + 1) * (m + 1) := Nat.le_of_not_lt h_succ_lt + -- m ≤ R_MAX (from m³ ≤ x < 2^256 and (R+1)³ ≥ 2^256) + have hm_le : m ≤ R_MAX := by + by_cases h : m ≤ R_MAX + · exact h + · exfalso + have hR1m : R_MAX + 1 ≤ m := by omega + have hcube : (R_MAX + 1) * (R_MAX + 1) * (R_MAX + 1) ≤ m * m * m := cube_monotone hR1m + have : (R_MAX + 1) * (R_MAX + 1) * (R_MAX + 1) ≤ x := Nat.le_trans hcube hmlo + have : 2 ^ 256 ≤ x := Nat.le_trans r_max_succ_cube_ge_word this + omega + -- m ≥ R_MAX (from (m+1)³ ≥ 2^256 and R_MAX³ < 2^256) + have hm_ge : R_MAX ≤ m := by + by_cases h : R_MAX ≤ m + · exact h + · exfalso + have hm1R : m + 1 ≤ R_MAX := by omega + have hcube : (m + 1) * (m + 1) * (m + 1) ≤ R_MAX * R_MAX * R_MAX := cube_monotone hm1R + have : (m + 1) * (m + 1) * (m + 1) < 2 ^ 256 := + Nat.lt_of_le_of_lt hcube r_max_cube_lt_word + omega + have hm_eq : m = R_MAX := Nat.le_antisymm hm_le hm_ge + -- Rewrite m = R_MAX everywhere + rw [hm_eq] at hmlo hmhi + -- z5 is in [R, R + hiD5] + have ⟨hmz5, hz5⟩ := run5_hi_bound x hx256 hx hmlo hmhi + -- innerCbrt = cbrtStep(x, z5) + have hseed : cbrtSeed x = seedOf ⟨247, by omega⟩ := + cbrtSeed_eq_certSeed _ x ⟨Nat.le_trans pow255_le_rmax_cube hmlo, hx256⟩ + have hinner_eq : innerCbrt x = cbrtStep x (run5From x (seedOf ⟨247, by omega⟩)) := by + rw [innerCbrt_eq_step_run5_seed, hseed] + -- cbrtStep(x, z5) ≤ R_MAX + have hz6 := cbrtStep_le_rmax x _ hx256 hmz5 hz5 + -- innerCbrt(x) ≤ R_MAX + have hinner_le : innerCbrt x ≤ R_MAX := hinner_eq ▸ hz6 + -- But innerCbrt(x) = icbrt(x) + 1 and icbrt(x) = R_MAX + rw [heqm1] at hinner_le + -- hinner_le : icbrt x + 1 ≤ R_MAX, hm_eq : icbrt x = R_MAX (since m := icbrt x) + have : icbrt x = R_MAX := hm_eq + omega + +end CbrtOverflow diff --git a/formal/cbrt/CbrtProof/CbrtProof/Wiring.lean b/formal/cbrt/CbrtProof/CbrtProof/Wiring.lean new file mode 100644 index 000000000..a5c8722c7 --- /dev/null +++ b/formal/cbrt/CbrtProof/CbrtProof/Wiring.lean @@ -0,0 +1,258 @@ +/- + Wiring: connect the finite certificate chain to the unconditional + upper bound `innerCbrt x ≤ icbrt x + 1` for all x < 2^256. + + Strategy: + - For x < 256: use decide (innerCbrt_upper_of_lt_256) + - For x ≥ 256: map x to certificate octave, verify seed/interval match, + apply CbrtCertified.run6_le_m_plus_one +-/ +import Init +import CbrtProof.CbrtCorrect +import CbrtProof.FiniteCert +import CbrtProof.CertifiedChain + +namespace CbrtWiring + +open CbrtCert +open CbrtCertified + +-- ============================================================================ +-- Octave membership: map x to its certificate octave +-- ============================================================================ + +/-- For x > 0, Nat.log2 gives the octave index. -/ +private theorem log2_octave (x : Nat) (hx : x ≠ 0) : + 2 ^ Nat.log2 x ≤ x ∧ x < 2 ^ (Nat.log2 x + 1) := + (Nat.log2_eq_iff hx).1 rfl + +/-- The seed depends only on log2(x), so it matches the certificate seed. -/ +theorem cbrtSeed_eq_certSeed (i : Fin 248) (x : Nat) + (hOct : 2 ^ (i.val + certOffset) ≤ x ∧ x < 2 ^ (i.val + certOffset + 1)) : + cbrtSeed x = seedOf i := by + have hx : 0 < x := Nat.lt_of_lt_of_le (Nat.two_pow_pos (i.val + certOffset)) hOct.1 + have hx0 : x ≠ 0 := Nat.ne_of_gt hx + have hlog : Nat.log2 x = i.val + certOffset := (Nat.log2_eq_iff hx0).2 hOct + unfold cbrtSeed + simp [hlog] + have hseed := seed_eq i + simp [seedOf] at hseed ⊢ + rw [hseed] + +/-- m = icbrt(x) lies within [loOf i, hiOf i] for x in octave i. -/ +theorem m_within_cert_interval + (i : Fin 248) (x m : Nat) + (hmlo : m * m * m ≤ x) + (hmhi : x < (m + 1) * (m + 1) * (m + 1)) + (hOct : 2 ^ (i.val + certOffset) ≤ x ∧ x < 2 ^ (i.val + certOffset + 1)) : + loOf i ≤ m ∧ m ≤ hiOf i := by + have hloSq : loOf i * loOf i * loOf i ≤ 2 ^ (i.val + certOffset) := lo_cube_le_pow2 i + have hloSqX : loOf i * loOf i * loOf i ≤ x := Nat.le_trans hloSq hOct.1 + have hlo : loOf i ≤ m := by + by_cases h : loOf i ≤ m + · exact h + · have hlt : m < loOf i := Nat.lt_of_not_ge h + have hm1 : m + 1 ≤ loOf i := Nat.succ_le_of_lt hlt + have hm1cube : (m + 1) * (m + 1) * (m + 1) ≤ loOf i * loOf i * loOf i := + cube_monotone hm1 + have hm1x : (m + 1) * (m + 1) * (m + 1) ≤ x := Nat.le_trans hm1cube hloSqX + exact False.elim ((Nat.not_lt_of_ge hm1x) hmhi) + have hhiSq : 2 ^ (i.val + certOffset + 1) ≤ + (hiOf i + 1) * (hiOf i + 1) * (hiOf i + 1) := + pow2_succ_le_hi_succ_cube i + have hXHi : x < (hiOf i + 1) * (hiOf i + 1) * (hiOf i + 1) := + Nat.lt_of_lt_of_le hOct.2 hhiSq + have hhi : m ≤ hiOf i := by + by_cases h : m ≤ hiOf i + · exact h + · have hlt : hiOf i < m := Nat.lt_of_not_ge h + have hhi1 : hiOf i + 1 ≤ m := Nat.succ_le_of_lt hlt + have hhicube : (hiOf i + 1) * (hiOf i + 1) * (hiOf i + 1) ≤ m * m * m := + cube_monotone hhi1 + have hXmm : x < m * m * m := Nat.lt_of_lt_of_le hXHi hhicube + exact False.elim ((Nat.not_lt_of_ge hmlo) hXmm) + exact ⟨hlo, hhi⟩ + +-- ============================================================================ +-- Certificate-backed upper bound +-- ============================================================================ + +/-- Certificate-backed upper bound for a single octave. + If x is in certificate octave i with m = icbrt(x), then innerCbrt x ≤ m + 1. -/ +theorem innerCbrt_upper_of_octave + (i : Fin 248) (x m : Nat) + (hmlo : m * m * m ≤ x) + (hmhi : x < (m + 1) * (m + 1) * (m + 1)) + (hOct : 2 ^ (i.val + certOffset) ≤ x ∧ x < 2 ^ (i.val + certOffset + 1)) : + innerCbrt x ≤ m + 1 := by + have hx : 0 < x := Nat.lt_of_lt_of_le (Nat.two_pow_pos _) hOct.1 + have hinterval := m_within_cert_interval i x m hmlo hmhi hOct + have hm2 : 2 ≤ m := Nat.le_trans (lo_ge_two i) hinterval.1 + have hseed : cbrtSeed x = seedOf i := cbrtSeed_eq_certSeed i x hOct + -- innerCbrt x = run6From x (cbrtSeed x) = run6From x (seedOf i) + have hrun : run6From x (seedOf i) ≤ m + 1 := + run6_le_m_plus_one i x m hm2 hmlo hmhi hinterval.1 hinterval.2 + -- Connect innerCbrt to run6From + have hinnerEq : innerCbrt x = run6From x (cbrtSeed x) := + innerCbrt_eq_run6From_seed x + calc innerCbrt x = run6From x (cbrtSeed x) := hinnerEq + _ = run6From x (seedOf i) := by rw [hseed] + _ ≤ m + 1 := hrun + +/-- Universal upper bound on uint256 domain: + for every x ∈ [1, 2^256-1], innerCbrt x ≤ icbrt x + 1. -/ +theorem innerCbrt_upper_u256 (x : Nat) (hx : 0 < x) (hx256 : x < 2 ^ 256) : + innerCbrt x ≤ icbrt x + 1 := by + by_cases hx_small : x < 256 + · exact innerCbrt_upper_of_lt_256 x hx_small + · -- x ≥ 256, use the finite certificate + have hx256_le : 256 ≤ x := Nat.le_of_not_lt hx_small + -- Map x to octave: log2(x) ∈ [8, 255] + have hx0 : x ≠ 0 := Nat.ne_of_gt hx + let n := Nat.log2 x + have hn_bounds : 8 ≤ n := by + dsimp [n] + -- log2(x) ≥ 8 because x ≥ 256 = 2^8 + have hoctave := log2_octave x hx0 + -- log2(x) is the unique k with 2^k ≤ x < 2^(k+1) + -- Since 2^8 = 256 ≤ x, we need 8 ≤ log2(x) + -- Proof: if log2(x) < 8, then x < 2^(log2(x)+1) ≤ 2^8 = 256, contradiction + by_cases h8 : 8 ≤ Nat.log2 x + · exact h8 + · have hlt : Nat.log2 x + 1 ≤ 8 := by omega + have hup : x < 2 ^ (Nat.log2 x + 1) := hoctave.2 + have hpow : 2 ^ (Nat.log2 x + 1) ≤ 2 ^ 8 := + Nat.pow_le_pow_right (by decide : 1 ≤ 2) hlt + have : x < 256 := Nat.lt_of_lt_of_le hup (by simpa using hpow) + omega + have hn_lt : n < 256 := by + dsimp [n] + exact (Nat.log2_lt hx0).2 hx256 + -- Certificate index + have hcert : certOffset = 8 := rfl + let idx : Fin 248 := ⟨n - certOffset, by omega⟩ + have hidx_plus : idx.val + certOffset = n := by dsimp [idx]; omega + -- Octave membership + have hOct : 2 ^ (idx.val + certOffset) ≤ x ∧ x < 2 ^ (idx.val + certOffset + 1) := by + rw [hidx_plus] + exact log2_octave x hx0 + -- Apply the certificate + let m := icbrt x + have hmlo : m * m * m ≤ x := icbrt_cube_le x + have hmhi : x < (m + 1) * (m + 1) * (m + 1) := icbrt_lt_succ_cube x + exact innerCbrt_upper_of_octave idx x m hmlo hmhi hOct + +/-- Universal floor correctness on uint256 domain: + for every x ∈ [1, 2^256-1], floorCbrt x = icbrt x. -/ +theorem floorCbrt_correct_u256 (x : Nat) (hx : 0 < x) (hx256 : x < 2 ^ 256) : + floorCbrt x = icbrt x := + floorCbrt_correct_of_upper x hx (innerCbrt_upper_u256 x hx hx256) + +/-- Universal floor correctness (including x = 0). -/ +theorem floorCbrt_correct_u256_all (x : Nat) (hx256 : x < 2 ^ 256) : + let r := floorCbrt x + r * r * r ≤ x ∧ x < (r + 1) * (r + 1) * (r + 1) := by + by_cases hx0 : x = 0 + · subst hx0; decide + · have hx : 0 < x := Nat.pos_of_ne_zero hx0 + have heq := floorCbrt_correct_u256 x hx hx256 + rw [heq] + exact ⟨icbrt_cube_le x, icbrt_lt_succ_cube x⟩ + +-- ============================================================================ +-- Perfect-cube exactness: innerCbrt(m³) = m +-- ============================================================================ + +/-- On a perfect cube, innerCbrt returns the exact cube root. + For m < 256: verified computationally (innerCbrt_on_perfect_cube_small). + For m ≥ 256: the certificate chain gives z₅ ≤ m + d5Of(i), and + d5Of(i)² < loOf(i) ≤ m (from d5_sq_lt_lo), so + cbrtStep_eq_on_perfect_cube_of_sq_lt gives z₆ = m. -/ +theorem innerCbrt_on_perfect_cube + (m : Nat) (hm : 0 < m) (hm256 : m * m * m < 2 ^ 256) : + innerCbrt (m * m * m) = m := by + by_cases hsmall : m < 256 + · exact innerCbrt_on_perfect_cube_small ⟨m, hsmall⟩ + · -- m ≥ 256 + have hm256_le : 256 ≤ m := Nat.le_of_not_lt hsmall + -- Lower bound: m ≤ innerCbrt(m³) + have hx_pos : 0 < m * m * m := Nat.mul_pos (Nat.mul_pos hm hm) hm + have hlo_inner : m ≤ innerCbrt (m * m * m) := + innerCbrt_lower (m * m * m) m hx_pos (Nat.le_refl _) + -- m³ < (m+1)³ (direct, not via icbrt) + have hm1_cube : m * m * m < (m + 1) * (m + 1) * (m + 1) := by + have h1 : m * m * m < m * m * (m + 1) := + Nat.mul_lt_mul_of_pos_left (by omega) (Nat.mul_pos hm hm) + have h2 : m * m ≤ (m + 1) * (m + 1) := Nat.mul_le_mul (Nat.le_succ m) (Nat.le_succ m) + exact Nat.lt_of_lt_of_le h1 (Nat.mul_le_mul_right _ h2) + -- icbrt(m³) = m + have hicbrt : icbrt (m * m * m) = m := + (icbrt_eq_of_bounds (m * m * m) m (Nat.le_refl _) hm1_cube).symm + -- Upper bound: innerCbrt(m³) ≤ m + 1 + have hhi_inner : innerCbrt (m * m * m) ≤ m + 1 := by + have := innerCbrt_upper_u256 (m * m * m) hx_pos hm256 + rw [hicbrt] at this; exact this + -- Rule out innerCbrt(m³) = m + 1 via certificate + by_cases heq : innerCbrt (m * m * m) = m + · exact heq + · exfalso + have heq1 : innerCbrt (m * m * m) = m + 1 := by omega + -- Map m³ to certificate octave + let x := m * m * m + have hx0 : x ≠ 0 := Nat.ne_of_gt hx_pos + let n := Nat.log2 x + have hn_bounds : 8 ≤ n := by + dsimp [n, x] + have hoctave := log2_octave x hx0 + by_cases h8 : 8 ≤ Nat.log2 x + · exact h8 + · have hlt : Nat.log2 x + 1 ≤ 8 := by omega + have hup : x < 2 ^ (Nat.log2 x + 1) := hoctave.2 + have hpow : 2 ^ (Nat.log2 x + 1) ≤ 2 ^ 8 := + Nat.pow_le_pow_right (by decide : 1 ≤ 2) hlt + have : x < 256 := Nat.lt_of_lt_of_le hup (by simpa using hpow) + -- But m ≥ 256, m³ ≥ 256³ > 256 + have : 256 * 256 * 256 ≤ x := + Nat.mul_le_mul (Nat.mul_le_mul hm256_le hm256_le) hm256_le + omega + have hn_lt : n < 256 := by + dsimp [n] + exact (Nat.log2_lt hx0).2 hm256 + have hcert : certOffset = 8 := rfl + let idx : Fin 248 := ⟨n - certOffset, by dsimp [n]; omega⟩ + have hidx_plus : idx.val + certOffset = n := by dsimp [idx, n]; omega + have hOct : 2 ^ (idx.val + certOffset) ≤ x ∧ x < 2 ^ (idx.val + certOffset + 1) := by + rw [hidx_plus] + exact log2_octave x hx0 + -- m is in [loOf idx, hiOf idx] + have hmlo_cube : m * m * m ≤ x := Nat.le_refl _ + have hmhi_cube : x < (m + 1) * (m + 1) * (m + 1) := hm1_cube + have hinterval := m_within_cert_interval idx x m hmlo_cube hmhi_cube hOct + have hm2 : 2 ≤ m := Nat.le_trans (by decide : 2 ≤ 256) hm256_le + -- Use shared 5-step certified bounds + have hseed : cbrtSeed x = seedOf idx := cbrtSeed_eq_certSeed idx x hOct + have ⟨hmz5, hd5, h2d5⟩ := run5_certified_bounds idx x m hm2 hmlo_cube hmhi_cube + hinterval.1 hinterval.2 + let z5 := run5From x (seedOf idx) + -- innerCbrt(x) = cbrtStep(x, z5) via run5From expansion + have hinner_run : innerCbrt x = cbrtStep x z5 := by + rw [innerCbrt_eq_step_run5_seed, hseed] + -- So cbrtStep(x, z5) = m + 1 + have hz6_eq : cbrtStep x z5 = m + 1 := by rw [← hinner_run]; exact heq1 + -- z₅ = m + e where e ≤ d5, e² < m, 2e ≤ m + have hd5sq_m : d5Of idx * d5Of idx < m := + Nat.lt_of_lt_of_le (d5_sq_lt_lo idx) hinterval.1 + have he_sq : (z5 - m) * (z5 - m) < m := + Nat.lt_of_le_of_lt (Nat.mul_le_mul hd5 hd5) hd5sq_m + have h2e : 2 * (z5 - m) ≤ m := + Nat.le_trans (Nat.mul_le_mul_left 2 hd5) h2d5 + -- cbrtStep(m³, z₅) = m via the perfect-cube lemma + have hz5_eq : z5 = m + (z5 - m) := by omega + have hz6_m : cbrtStep x z5 = m := by + show cbrtStep (m * m * m) z5 = m + rw [hz5_eq] + exact cbrtStep_eq_on_perfect_cube_of_sq_lt m (z5 - m) hm2 h2e he_sq + -- Contradiction: cbrtStep(x, z5) = m but also = m + 1 + omega + +end CbrtWiring diff --git a/formal/cbrt/CbrtProof/Main.lean b/formal/cbrt/CbrtProof/Main.lean new file mode 100644 index 000000000..0600de6d4 --- /dev/null +++ b/formal/cbrt/CbrtProof/Main.lean @@ -0,0 +1,52 @@ +import CbrtProof.GeneratedCbrtModel + +/-! +# Cbrt model evaluator + +Compiled executable for evaluating the generated EVM-faithful Cbrt model +on concrete inputs. Intended for fuzz testing via Foundry's `vm.ffi`. + +Usage: + cbrt-model + +Functions: cbrt, cbrt_floor, cbrt_up + +Output: 0x-prefixed hex uint256 on stdout. +-/ + +open CbrtGeneratedModel in +def evalFunction (name : String) (x : Nat) : Option Nat := + match name with + | "cbrt" => some (model_cbrt_evm x) + | "cbrt_floor" => some (model_cbrt_floor_evm x) + | "cbrt_up" => some (model_cbrt_up_evm x) + | _ => none + +def natToHex64 (n : Nat) : String := + let hex := String.ofList (Nat.toDigits 16 n) + "0x" ++ String.ofList (List.replicate (64 - hex.length) '0') ++ hex + +def parseHex (s : String) : Option Nat := + let s := if s.startsWith "0x" || s.startsWith "0X" then s.drop 2 else s + s.foldl (fun acc c => + acc.bind fun n => + if '0' ≤ c && c ≤ '9' then some (n * 16 + (c.toNat - '0'.toNat)) + else if 'a' ≤ c && c ≤ 'f' then some (n * 16 + (c.toNat - 'a'.toNat + 10)) + else if 'A' ≤ c && c ≤ 'F' then some (n * 16 + (c.toNat - 'A'.toNat + 10)) + else none + ) (some 0) + +def main (args : List String) : IO UInt32 := do + match args with + | [fnName, hexX] => + match parseHex hexX with + | none => IO.eprintln s!"Invalid hex input: {hexX}"; return 1 + | some x => + match evalFunction fnName x with + | none => IO.eprintln s!"Unknown function: {fnName}"; return 1 + | some result => + IO.println (natToHex64 result) + return 0 + | _ => + IO.eprintln "Usage: cbrt-model " + return 1 diff --git a/formal/cbrt/CbrtProof/lakefile.toml b/formal/cbrt/CbrtProof/lakefile.toml new file mode 100644 index 000000000..69bad8248 --- /dev/null +++ b/formal/cbrt/CbrtProof/lakefile.toml @@ -0,0 +1,10 @@ +name = "CbrtProof" +version = "0.1.0" +defaultTargets = ["CbrtProof"] + +[[lean_lib]] +name = "CbrtProof" + +[[lean_exe]] +name = "cbrt-model" +root = "Main" diff --git a/formal/cbrt/CbrtProof/lean-toolchain b/formal/cbrt/CbrtProof/lean-toolchain new file mode 100644 index 000000000..4c685fa08 --- /dev/null +++ b/formal/cbrt/CbrtProof/lean-toolchain @@ -0,0 +1 @@ +leanprover/lean4:v4.28.0 diff --git a/formal/cbrt/generate_cbrt_cert.py b/formal/cbrt/generate_cbrt_cert.py new file mode 100644 index 000000000..42aeb3189 --- /dev/null +++ b/formal/cbrt/generate_cbrt_cert.py @@ -0,0 +1,364 @@ +#!/usr/bin/env python3 +""" +Generate finite-certificate tables for the cbrt formal proof. + +For each of 256 octaves (n = 0..255), where octave n contains x in [2^n, 2^(n+1) - 1]: + - loOf(n) = icbrt(2^n) -- lower bound on icbrt(x) + - hiOf(n) = icbrt(2^(n+1) - 1) -- upper bound on icbrt(x) + - seedOf(n) = cbrt seed for octave n + - d1(n): first-step error bound from algebraic formula + - d2..d6(n): chained via nextD(lo, d) = d^2/lo + 1 + +The d1 bound uses the cubic identity: + 3*s^2*(z1 - m) <= (m-s)^2*(m+2s) + 3*m*(m+1) + <= maxAbs^2*(hi+2s) + 3*hi*(hi+1) +where maxAbs = max(|s-lo|, |hi-s|). + +Octaves 0-7 (x < 256) are handled separately by decide in Lean. +The certificate covers octaves 8-255 (x >= 256, lo >= 6). +""" + +import argparse +import sys + + +def icbrt(x): + """Integer cube root (floor).""" + if x <= 0: + return 0 + if x < 8: + return 1 + n = x.bit_length() + z = 1 << ((n + 2) // 3) + while True: + z1 = (2 * z + x // (z * z)) // 3 + if z1 >= z: + break + z = z1 + while z * z * z > x: + z -= 1 + while (z + 1) ** 3 <= x: + z += 1 + return z + + +def cbrt_step(x, z): + """One NR step: floor((floor(x/(z*z)) + 2*z) / 3)""" + if z == 0: + return 0 + return (x // (z * z) + 2 * z) // 3 + + +def cbrt_seed(n): + """Seed for octave n.""" + q = (n + 2) // 3 + return ((0xe9 << q) >> 8) + 1 + + +def next_d(lo, d): + """Error recurrence: d^2/lo + 1.""" + if lo == 0: + return d * d + 1 + return d * d // lo + 1 + + +def compute_maxabs(lo, hi, s): + """max(|s - lo|, |hi - s|)""" + return max(abs(s - lo), abs(hi - s)) + + +def compute_d1(lo, hi, s): + """Analytic d1 bound: + d1 = floor((maxAbs^2*(hi+2s) + 3*hi*(hi+1)) / (3*s^2)) + """ + maxAbs = compute_maxabs(lo, hi, s) + numerator = maxAbs * maxAbs * (hi + 2 * s) + 3 * hi * (hi + 1) + denominator = 3 * s * s + if denominator == 0: + return 0 + return numerator // denominator + + +def main(): + parser = argparse.ArgumentParser( + description="Generate finite-certificate tables for cbrt formal proof" + ) + parser.add_argument( + "--output", + default="CbrtProof/CbrtProof/FiniteCert.lean", + help="Output Lean file path (default: CbrtProof/CbrtProof/FiniteCert.lean)", + ) + args = parser.parse_args() + + lo_table = [] + hi_table = [] + + for n in range(256): + lo = icbrt(1 << n) + hi = icbrt((1 << (n + 1)) - 1) + lo_table.append(lo) + hi_table.append(hi) + + # Verify basic properties + for n in range(256): + lo = lo_table[n] + hi = hi_table[n] + assert lo * lo * lo <= (1 << n), f"lo^3 > 2^n at n={n}" + assert (1 << (n + 1)) <= (hi + 1) ** 3, f"2^(n+1) > (hi+1)^3 at n={n}" + assert lo <= hi, f"lo > hi at n={n}" + + # Compute certificate for octaves 8-255 + START_OCTAVE = 8 + all_ok = True + d_data = {} # n -> (d1, ..., d6) + + for n in range(START_OCTAVE, 256): + lo = lo_table[n] + hi = hi_table[n] + seed = cbrt_seed(n) + + d1 = compute_d1(lo, hi, seed) + d2 = next_d(lo, d1) + d3 = next_d(lo, d2) + d4 = next_d(lo, d3) + d5 = next_d(lo, d4) + d6 = next_d(lo, d5) + d_data[n] = (d1, d2, d3, d4, d5, d6) + + if d6 > 1: + print(f"FAIL d6: n={n}, d1={d1}, d2={d2}, d3={d3}, " + f"d4={d4}, d5={d5}, d6={d6}, lo={lo}") + all_ok = False + + # Check side conditions: 2*dk <= lo for k=1..5 + for k, dk in enumerate([d1, d2, d3, d4, d5], 1): + if 2 * dk > lo: + print(f"SIDE FAIL: n={n}, 2*d{k}={2*dk} > lo={lo}") + all_ok = False + + if all_ok: + print(f"All octaves {START_OCTAVE}-255 pass: d6 <= 1, all side conditions OK.") + else: + print("SOME OCTAVES FAIL.") + + # Exhaustive verification for small octaves to confirm d1 bound + print(f"\nExhaustive verification of d1 for octaves {START_OCTAVE}-30...") + for n in range(START_OCTAVE, min(31, 256)): + lo = lo_table[n] + hi = hi_table[n] + seed = cbrt_seed(n) + d1_cert = d_data[n][0] + + for m in range(lo, hi + 1): + x_lo_m = max(m * m * m, 1 << n) + x_hi_m = min((m + 1) ** 3 - 1, (1 << (n + 1)) - 1) + if x_lo_m > x_hi_m: + continue + z1 = cbrt_step(x_hi_m, seed) # max z1 by mono in x + actual_d1 = max(0, z1 - m) + if actual_d1 > d1_cert: + print(f" D1 FAIL: n={n}, m={m}, z1={z1}, actual_d1={actual_d1}, cert={d1_cert}") + all_ok = False + print(" d1 exhaustive check done.") + + # Spot-check d1 for large octaves (random m values) + import random + random.seed(42) + print("\nSpot-checking d1 for large octaves...") + for n in range(100, 256, 10): + lo = lo_table[n] + hi = hi_table[n] + seed = cbrt_seed(n) + d1_cert = d_data[n][0] + + # Test at lo, hi, and random m values + for m in [lo, hi, lo + (hi - lo) // 3, lo + 2 * (hi - lo) // 3]: + x_max = min((m + 1) ** 3 - 1, (1 << (n + 1)) - 1) + x_min = max(m ** 3, 1 << n) + if x_min > x_max: + continue + z1 = cbrt_step(x_max, seed) + actual_d1 = max(0, z1 - m) + if actual_d1 > d1_cert: + print(f" SPOT FAIL: n={n}, m={m}, z1={z1}, actual_d1={actual_d1}, cert={d1_cert}") + all_ok = False + print(" Spot check done.") + + # Also check lo_pos: lo >= 6 for octaves >= 8 + assert all(lo_table[n] >= 6 for n in range(START_OCTAVE, 256)), \ + "lo < 6 for some octave >= 8!" + print(f"\nAll lo >= 6 for octaves >= {START_OCTAVE}. ✓") + + # Also check 2 <= lo (needed for cbrtStep_upper_of_le) + assert all(lo_table[n] >= 2 for n in range(START_OCTAVE, 256)), \ + "lo < 2 for some octave >= 8!" + + # Summary + print(f"\n--- Summary (octaves {START_OCTAVE}-255) ---") + for k in range(6): + vals = [d_data[n][k] for n in range(START_OCTAVE, 256)] + mx = max(vals) + mi = START_OCTAVE + vals.index(mx) + print(f" Max d{k+1}: {mx} at n={mi}") + + # Print d1/lo ratios for a few octaves + print(f"\n--- d1/lo ratios ---") + for n in [8, 10, 20, 50, 85, 100, 123, 170, 200, 255]: + if n >= START_OCTAVE: + lo = lo_table[n] + d1 = d_data[n][0] + print(f" n={n}: lo={lo}, d1={d1}, d1/lo={d1/lo:.6f}, 2d1/lo={2*d1/lo:.6f}") + + # Generate Lean output + if all_ok: + generate_lean_file(lo_table, hi_table, d_data, START_OCTAVE, args.output) + + return 0 if all_ok else 1 + + +def generate_lean_file(lo_table, hi_table, d_data, start_octave, outpath): + """Generate the CbrtFiniteCert.lean file.""" + print(f"\nGenerating {outpath}...") + + num = 256 - start_octave # 248 entries + + def fmt_array(name, values, comment=""): + lines = [] + if comment: + lines.append(f"/-- {comment} -/") + lines.append(f"def {name} : Array Nat := #[") + for i, v in enumerate(values): + comma = "," if i < len(values) - 1 else "" + lines.append(f" {v}{comma}") + lines.append("]") + return "\n".join(lines) + + lo_vals = lo_table[start_octave:] + hi_vals = hi_table[start_octave:] + seed_vals = [cbrt_seed(n) for n in range(start_octave, 256)] + d1_vals = [d_data[n][0] for n in range(start_octave, 256)] + + # Compute maxAbs values for the d1 bound proof + maxabs_vals = [compute_maxabs(lo_table[n], hi_table[n], cbrt_seed(n)) + for n in range(start_octave, 256)] + + content = f"""import Init + +/- + Finite certificate for cbrt upper bound, covering octaves {start_octave}..255. + + For each octave i (offset from {start_octave}), the tables provide: + - loOf(i): lower bound on icbrt(x) for x in [2^(i+{start_octave}), 2^(i+{start_octave+1})-1] + - hiOf(i): upper bound on icbrt(x) + - seedOf(i): the cbrt seed for the octave + - maxAbsOf(i): max|seed - m| for m in [lo, hi] + - d1Of(i): first-step error bound (analytic) + - nextD, d2..d6: chained error recurrence + + All 248 octaves verified: d6 <= 1 and 2*dk <= lo for k=1..5. +-/ + +namespace CbrtCert + +set_option maxRecDepth 1000000 + +/-- Offset: certificate octave index i corresponds to bit-length octave i + {start_octave}. -/ +def certOffset : Nat := {start_octave} + +{fmt_array("loTable", lo_vals, f"Lower bounds on icbrt(x) for octaves {start_octave}..255.")} + +{fmt_array("hiTable", hi_vals, f"Upper bounds on icbrt(x) for octaves {start_octave}..255.")} + +{fmt_array("seedTable", seed_vals, f"cbrt seed for octaves {start_octave}..255.")} + +{fmt_array("maxAbsTable", maxabs_vals, "max(|seed - lo|, |hi - seed|) per octave.")} + +{fmt_array("d1Table", d1_vals, "First-step error bound per octave.")} + +def loOf (i : Fin {num}) : Nat := loTable[i.val]! +def hiOf (i : Fin {num}) : Nat := hiTable[i.val]! +def seedOf (i : Fin {num}) : Nat := seedTable[i.val]! +def maxAbsOf (i : Fin {num}) : Nat := maxAbsTable[i.val]! +def d1Of (i : Fin {num}) : Nat := d1Table[i.val]! + +/-- Error recurrence: d^2/lo + 1. -/ +def nextD (lo d : Nat) : Nat := d * d / lo + 1 + +def d2Of (i : Fin {num}) : Nat := nextD (loOf i) (d1Of i) +def d3Of (i : Fin {num}) : Nat := nextD (loOf i) (d2Of i) +def d4Of (i : Fin {num}) : Nat := nextD (loOf i) (d3Of i) +def d5Of (i : Fin {num}) : Nat := nextD (loOf i) (d4Of i) +def d6Of (i : Fin {num}) : Nat := nextD (loOf i) (d5Of i) + +-- ============================================================================ +-- Computational verification of certificate properties +-- ============================================================================ + +/-- lo is always positive. -/ +theorem lo_pos : ∀ i : Fin {num}, 0 < loOf i := by decide + +/-- lo >= 2 (needed for cbrtStep_upper_of_le). -/ +theorem lo_ge_two : ∀ i : Fin {num}, 2 ≤ loOf i := by decide + +/-- lo <= hi. -/ +theorem lo_le_hi : ∀ i : Fin {num}, loOf i ≤ hiOf i := by decide + +/-- seed is positive. -/ +theorem seed_pos : ∀ i : Fin {num}, 0 < seedOf i := by decide + +/-- lo^3 <= 2^(i + certOffset). -/ +theorem lo_cube_le_pow2 : ∀ i : Fin {num}, + loOf i * loOf i * loOf i ≤ 2 ^ (i.val + certOffset) := by decide + +/-- 2^(i + certOffset + 1) <= (hi+1)^3. -/ +theorem pow2_succ_le_hi_succ_cube : ∀ i : Fin {num}, + 2 ^ (i.val + certOffset + 1) ≤ (hiOf i + 1) * (hiOf i + 1) * (hiOf i + 1) := by decide + +/-- d1 is the correct analytic bound: + d1Of(i) = (maxAbsOf(i)^2 * (hiOf(i) + 2*seedOf(i)) + 3*hiOf(i)*(hiOf(i)+1)) / (3*seedOf(i)^2) -/ +theorem d1_eq : ∀ i : Fin {num}, + d1Of i = (maxAbsOf i * maxAbsOf i * (hiOf i + 2 * seedOf i) + + 3 * hiOf i * (hiOf i + 1)) / (3 * (seedOf i * seedOf i)) := by decide + +/-- maxAbs captures the correct value. -/ +theorem maxabs_eq : ∀ i : Fin {num}, + maxAbsOf i = max (seedOf i - loOf i) (hiOf i - seedOf i) := by decide + +/-- Terminal bound: d6 <= 1 for all certificate octaves. -/ +theorem d6_le_one : ∀ i : Fin {num}, d6Of i ≤ 1 := by decide + +/-- Side condition: 2 * d1 <= lo. -/ +theorem two_d1_le_lo : ∀ i : Fin {num}, 2 * d1Of i ≤ loOf i := by decide + +/-- Side condition: 2 * d2 <= lo. -/ +theorem two_d2_le_lo : ∀ i : Fin {num}, 2 * d2Of i ≤ loOf i := by decide + +/-- Side condition: 2 * d3 <= lo. -/ +theorem two_d3_le_lo : ∀ i : Fin {num}, 2 * d3Of i ≤ loOf i := by decide + +/-- Side condition: 2 * d4 <= lo. -/ +theorem two_d4_le_lo : ∀ i : Fin {num}, 2 * d4Of i ≤ loOf i := by decide + +/-- Side condition: 2 * d5 <= lo. -/ +theorem two_d5_le_lo : ∀ i : Fin {num}, 2 * d5Of i ≤ loOf i := by decide + +/-- Seed matches the cbrt seed formula: + seedOf(i) = ((0xe9 <<< ((i + certOffset + 2) / 3)) >>> 8) + 1 -/ +theorem seed_eq : ∀ i : Fin {num}, + seedOf i = ((0xe9 <<< ((i.val + certOffset + 2) / 3)) >>> 8) + 1 := by decide + +/-- Perfect-cube key: d5² < lo for all certificate octaves. + This ensures that on perfect cubes x = m³, the 6th NR step gives exactly m + (since the per-step error d²/m < 1 when d² < m and m ≥ lo). -/ +theorem d5_sq_lt_lo : ∀ i : Fin {num}, d5Of i * d5Of i < loOf i := by decide + +end CbrtCert +""" + + with open(outpath, "w") as f: + f.write(content) + print(f" Written to {outpath}") + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/formal/cbrt/generate_cbrt_model.py b/formal/cbrt/generate_cbrt_model.py new file mode 100755 index 000000000..40b27d88d --- /dev/null +++ b/formal/cbrt/generate_cbrt_model.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 +""" +Generate Lean models of Cbrt.sol from Yul IR. + +This script extracts `_cbrt`, `cbrt`, and `cbrtUp` from the Yul IR produced by +`forge inspect` on a wrapper contract and emits Lean definitions for: +- opcode-faithful uint256 EVM semantics, and +- normalized Nat semantics. +""" + +from __future__ import annotations + +import sys +from pathlib import Path + +# Allow importing the shared module from formal/ +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from yul_to_lean import Call, Expr, IntLit, ModelConfig, run + + +def rewrite_norm_ast(expr: Expr) -> Expr: + """Rewrite sub(257, clz(arg)) → bitLengthPlus1(arg) for the Nat model. + + In Nat arithmetic, normSub 257 (normClz x) = 257 - (255 - log2 x) underflows + for x ≥ 2^256 because 255 - log2 x truncates to 0. normBitLengthPlus1(x) + computes log2(x) + 2 directly, giving the correct value for all Nat. + """ + if isinstance(expr, Call): + args = tuple(rewrite_norm_ast(a) for a in expr.args) + if ( + expr.name == "sub" + and len(args) == 2 + and isinstance(args[0], IntLit) + and args[0].value == 257 + and isinstance(args[1], Call) + and args[1].name == "clz" + and len(args[1].args) == 1 + ): + return Call("bitLengthPlus1", args[1].args) + return Call(expr.name, args) + return expr + + +CONFIG = ModelConfig( + function_order=("_cbrt", "cbrt", "cbrtUp"), + model_names={ + "_cbrt": "model_cbrt", + "cbrt": "model_cbrt_floor", + "cbrtUp": "model_cbrt_up", + }, + header_comment="Auto-generated from Solidity Cbrt assembly and assignment flow.", + generator_label="formal/cbrt/generate_cbrt_model.py", + extra_norm_ops={"bitLengthPlus1": "normBitLengthPlus1"}, + extra_lean_defs=( + "def normBitLengthPlus1 (value : Nat) : Nat :=\n" + " if value = 0 then 1 else Nat.log2 value + 2\n\n" + ), + norm_rewrite=rewrite_norm_ast, + inner_fn="_cbrt", + default_source_label="src/vendor/Cbrt.sol", + default_namespace="CbrtGeneratedModel", + default_output="formal/cbrt/CbrtProof/CbrtProof/GeneratedCbrtModel.lean", + cli_description="Generate Lean model of Cbrt.sol functions from Yul IR", +) + + +if __name__ == "__main__": + raise SystemExit(run(CONFIG)) diff --git a/formal/sqrt/Sqrt512Proof/.gitignore b/formal/sqrt/Sqrt512Proof/.gitignore new file mode 100644 index 000000000..e0793a203 --- /dev/null +++ b/formal/sqrt/Sqrt512Proof/.gitignore @@ -0,0 +1,5 @@ +/.lake +lake-manifest.json + +# Auto-generated from `formal/sqrt/generate_sqrt512_model.py` +/Sqrt512Proof/GeneratedSqrt512Model.lean diff --git a/formal/sqrt/Sqrt512Proof/Main.lean b/formal/sqrt/Sqrt512Proof/Main.lean new file mode 100644 index 000000000..de2462036 --- /dev/null +++ b/formal/sqrt/Sqrt512Proof/Main.lean @@ -0,0 +1,70 @@ +import Sqrt512Proof.GeneratedSqrt512Model + +/-! +# Sqrt512 model evaluator + +Compiled executable for evaluating the generated EVM-faithful 512-bit +Sqrt model on concrete inputs. Intended for fuzz testing via Foundry's +`vm.ffi`. + +Usage: + sqrt512-model sqrt512 → 1 hex word + sqrt512-model sqrt512_wrapper → 1 hex word + sqrt512-model osqrtUp → 2 hex words (ABI-encoded) +-/ + +open Sqrt512GeneratedModel in +def evalFunction1 (name : String) (xHi xLo : Nat) : Option Nat := + match name with + | "sqrt512" => some (model_sqrt512_evm xHi xLo) + | "sqrt512_wrapper" => some (model_sqrt512_wrapper_evm xHi xLo) + | _ => none + +open Sqrt512GeneratedModel in +def evalFunction2 (name : String) (xHi xLo : Nat) : Option (Nat × Nat) := + match name with + | "osqrtUp" => some (model_osqrtUp_evm xHi xLo) + | _ => none + +def natToHex64 (n : Nat) : String := + let hex := String.ofList (Nat.toDigits 16 n) + "0x" ++ String.ofList (List.replicate (64 - hex.length) '0') ++ hex + +def parseHex (s : String) : Option Nat := + let s := if s.startsWith "0x" || s.startsWith "0X" then s.drop 2 else s + s.foldl (fun acc c => + acc.bind fun n => + if '0' ≤ c && c ≤ '9' then some (n * 16 + (c.toNat - '0'.toNat)) + else if 'a' ≤ c && c ≤ 'f' then some (n * 16 + (c.toNat - 'a'.toNat + 10)) + else if 'A' ≤ c && c ≤ 'F' then some (n * 16 + (c.toNat - 'A'.toNat + 10)) + else none + ) (some 0) + +def main (args : List String) : IO UInt32 := do + match args with + | [fnName, hexHi, hexLo] => + match parseHex hexHi, parseHex hexLo with + | some hi, some lo => + -- Try single-word functions first + match evalFunction1 fnName hi lo with + | some result => + -- ABI-encode as single uint256: 32 bytes zero-padded + IO.println (natToHex64 result) + return 0 + | none => + -- Try two-word functions + match evalFunction2 fnName hi lo with + | some (rHi, rLo) => + -- ABI-encode as (uint256, uint256): 64 bytes + -- Foundry's vm.ffi decodes the stdout as raw ABI bytes + IO.println (natToHex64 rHi ++ (natToHex64 rLo).drop 2) + return 0 + | none => + IO.eprintln s!"Unknown function: {fnName}" + return 1 + | _, _ => + IO.eprintln s!"Invalid hex input" + return 1 + | _ => + IO.eprintln "Usage: sqrt512-model " + return 1 diff --git a/formal/sqrt/Sqrt512Proof/Sqrt512Proof.lean b/formal/sqrt/Sqrt512Proof/Sqrt512Proof.lean new file mode 100644 index 000000000..e8c4e814a --- /dev/null +++ b/formal/sqrt/Sqrt512Proof/Sqrt512Proof.lean @@ -0,0 +1,9 @@ +-- Root of the Sqrt512Proof library. +import Sqrt512Proof.Normalization +import Sqrt512Proof.KaratsubaStep +import Sqrt512Proof.Correction +import Sqrt512Proof.Sqrt512Correct +import Sqrt512Proof.SqrtUpCorrect +import Sqrt512Proof.GeneratedSqrt512Spec +import Sqrt512Proof.SqrtWrapperSpec +import Sqrt512Proof.OsqrtUpSpec diff --git a/formal/sqrt/Sqrt512Proof/Sqrt512Proof/Correction.lean b/formal/sqrt/Sqrt512Proof/Sqrt512Proof/Correction.lean new file mode 100644 index 000000000..f0bddfa75 --- /dev/null +++ b/formal/sqrt/Sqrt512Proof/Sqrt512Proof/Correction.lean @@ -0,0 +1,40 @@ +/- + Correction step for 512-bit square root. + + After Karatsuba, r in {natSqrt(x), natSqrt(x) + 1}. + Checking x < r^2 and decrementing gives exactly natSqrt(x). +-/ +import SqrtProof.SqrtCorrect + +/-- If natSqrt(x) <= r <= natSqrt(x) + 1, then + (if x < r^2 then r-1 else r) = natSqrt(x). -/ +theorem correction_correct (x r : Nat) + (hlo : natSqrt x ≤ r) (hhi : r ≤ natSqrt x + 1) : + (if x < r * r then r - 1 else r) = natSqrt x := by + have hmlo : natSqrt x * natSqrt x ≤ x := natSqrt_sq_le x + have hmhi : x < (natSqrt x + 1) * (natSqrt x + 1) := natSqrt_lt_succ_sq x + -- r is either natSqrt x or natSqrt x + 1 + have hrm : r = natSqrt x ∨ r = natSqrt x + 1 := by omega + rcases hrm with rfl | rfl + · -- r = natSqrt x: (natSqrt x)^2 <= x so not (x < (natSqrt x)^2) + simp [Nat.not_lt.mpr hmlo] + · -- r = natSqrt x + 1: x < (natSqrt x + 1)^2, so decrement + simp [hmhi] + +/-- Correction produces the natSqrt spec. -/ +theorem correction_spec (x r : Nat) + (hlo : natSqrt x ≤ r) (hhi : r ≤ natSqrt x + 1) : + let r' := if x < r * r then r - 1 else r + r' * r' ≤ x ∧ x < (r' + 1) * (r' + 1) := by + have h := correction_correct x r hlo hhi + intro r' + -- r' = natSqrt x + have hr'_eq : r' = natSqrt x := h + rw [hr'_eq] + exact ⟨natSqrt_sq_le x, natSqrt_lt_succ_sq x⟩ + +/-- From the Karatsuba identity x + q^2 = r^2 + rem*H + x_lo_lo, + x < r^2 <-> rem*H + x_lo_lo < q^2. -/ +theorem correction_equiv (x q r rem_H x_lo_lo : Nat) + (hident : x + q * q = r * r + rem_H + x_lo_lo) : + (x < r * r) ↔ (rem_H + x_lo_lo < q * q) := by omega diff --git a/formal/sqrt/Sqrt512Proof/Sqrt512Proof/GeneratedSqrt512Spec.lean b/formal/sqrt/Sqrt512Proof/Sqrt512Proof/GeneratedSqrt512Spec.lean new file mode 100644 index 000000000..bc8e672eb --- /dev/null +++ b/formal/sqrt/Sqrt512Proof/Sqrt512Proof/GeneratedSqrt512Spec.lean @@ -0,0 +1,1684 @@ +/- + Bridge from model_sqrt512_evm to natSqrt: specification layer. + + Part 1 (fully proved): Fixed-seed convergence certificate. + Part 2: EVM model bridge — model_sqrt512_evm = sqrt512. + Part 3 (fully proved): Composition — sqrt512 = natSqrt. + + Architecture: model_sqrt512_evm →[direct EVM bridge]→ sqrt512 →[proved]→ natSqrt + + Note: The auto-generated norm model (model_sqrt512) uses unbounded Nat operations + (normShl, normMul) which do NOT match EVM uint256 semantics. Therefore we prove the + EVM model correct directly, without factoring through the norm model. +-/ +import Sqrt512Proof.Sqrt512Correct +import Sqrt512Proof.GeneratedSqrt512Model + +namespace Sqrt512Spec + +open SqrtCert +open SqrtBridge +open SqrtCertified +open Sqrt512Cert hiding FIXED_SEED + +-- ============================================================================ +-- Section 1: Fixed-seed definitions +-- ============================================================================ + +/-- The fixed Newton seed used by 512-bit sqrt: floor(sqrt(2^255)). + Equals hiOf(254) = loOf(255) in the finite certificate tables. -/ +def FIXED_SEED : Nat := 240615969168004511545033772477625056927 + +theorem fixed_seed_pos : 0 < FIXED_SEED := by decide + +/-- Run 6 Babylonian steps from the fixed seed. -/ +def run6Fixed (x : Nat) : Nat := + let z := bstep x FIXED_SEED + let z := bstep x z + let z := bstep x z + let z := bstep x z + let z := bstep x z + let z := bstep x z + z + +/-- Floor square root using the fixed seed: 6 Newton steps + correction. -/ +def floorSqrt_fixed (x : Nat) : Nat := + let z := run6Fixed x + if z = 0 then 0 else if x / z < z then z - 1 else z + +-- ============================================================================ +-- Section 2: Fixed-seed convergence for octave 254 (x ∈ [2^254, 2^255)) +-- Certificate definitions (lo254, fd1_254, etc.) are in Sqrt512Cert +-- (auto-generated in SqrtProof/SqrtProof/FiniteCert.lean). +-- ============================================================================ + +set_option maxRecDepth 100000 in +private theorem run6Fixed_error_254 + (x m : Nat) (hm : 0 < m) (hmlo : m * m ≤ x) + (hmhi : x < (m + 1) * (m + 1)) (hlo : lo254 ≤ m) (hhi : m ≤ hi254) : + run6Fixed x - m ≤ fd6_254 := by + let z1 := bstep x FIXED_SEED; let z2 := bstep x z1; let z3 := bstep x z2 + let z4 := bstep x z3; let z5 := bstep x z4; let z6 := bstep x z5 + have hmz1 : m ≤ z1 := babylon_step_floor_bound x FIXED_SEED m fixed_seed_pos hmlo + have hz1P : 0 < z1 := Nat.lt_of_lt_of_le hm hmz1 + have hmz2 : m ≤ z2 := babylon_step_floor_bound x z1 m hz1P hmlo + have hz2P : 0 < z2 := Nat.lt_of_lt_of_le hm hmz2 + have hmz3 : m ≤ z3 := babylon_step_floor_bound x z2 m hz2P hmlo + have hz3P : 0 < z3 := Nat.lt_of_lt_of_le hm hmz3 + have hmz4 : m ≤ z4 := babylon_step_floor_bound x z3 m hz3P hmlo + have hz4P : 0 < z4 := Nat.lt_of_lt_of_le hm hmz4 + have hmz5 : m ≤ z5 := babylon_step_floor_bound x z4 m hz4P hmlo + have hd1 : z1 - m ≤ fd1_254 := by + simpa [z1, fd1_254, maxAbs254] using + d1_bound x m FIXED_SEED lo254 hi254 fixed_seed_pos hmlo hmhi hlo hhi + have hd1m : fd1_254 ≤ m := Nat.le_trans fd1_254_le_lo hlo + have hd2 : z2 - m ≤ fd2_254 := by + simpa [z2, fd2_254] using step_from_bound x m lo254 z1 fd1_254 hm lo254_pos hlo hmhi hmz1 hd1 hd1m + have hd2m : fd2_254 ≤ m := Nat.le_trans fd2_254_le_lo hlo + have hd3 : z3 - m ≤ fd3_254 := by + simpa [z3, fd3_254] using step_from_bound x m lo254 z2 fd2_254 hm lo254_pos hlo hmhi hmz2 hd2 hd2m + have hd3m : fd3_254 ≤ m := Nat.le_trans fd3_254_le_lo hlo + have hd4 : z4 - m ≤ fd4_254 := by + simpa [z4, fd4_254] using step_from_bound x m lo254 z3 fd3_254 hm lo254_pos hlo hmhi hmz3 hd3 hd3m + have hd4m : fd4_254 ≤ m := Nat.le_trans fd4_254_le_lo hlo + have hd5 : z5 - m ≤ fd5_254 := by + simpa [z5, fd5_254] using step_from_bound x m lo254 z4 fd4_254 hm lo254_pos hlo hmhi hmz4 hd4 hd4m + have hd5m : fd5_254 ≤ m := Nat.le_trans fd5_254_le_lo hlo + have hd6 : z6 - m ≤ fd6_254 := by + simpa [z6, fd6_254] using step_from_bound x m lo254 z5 fd5_254 hm lo254_pos hlo hmhi hmz5 hd5 hd5m + simpa [run6Fixed, z1, z2, z3, z4, z5, z6] using hd6 + +-- ============================================================================ +-- Section 3: Fixed-seed convergence for octave 255 (x ∈ [2^255, 2^256)) +-- Certificate definitions (lo255, fd1_255, etc.) are in Sqrt512Cert. +-- ============================================================================ + +set_option maxRecDepth 100000 in +private theorem run6Fixed_error_255 + (x m : Nat) (hm : 0 < m) (hmlo : m * m ≤ x) + (hmhi : x < (m + 1) * (m + 1)) (hlo : lo255 ≤ m) (hhi : m ≤ hi255) : + run6Fixed x - m ≤ fd6_255 := by + let z1 := bstep x FIXED_SEED; let z2 := bstep x z1; let z3 := bstep x z2 + let z4 := bstep x z3; let z5 := bstep x z4; let z6 := bstep x z5 + have hmz1 : m ≤ z1 := babylon_step_floor_bound x FIXED_SEED m fixed_seed_pos hmlo + have hz1P : 0 < z1 := Nat.lt_of_lt_of_le hm hmz1 + have hmz2 : m ≤ z2 := babylon_step_floor_bound x z1 m hz1P hmlo + have hz2P : 0 < z2 := Nat.lt_of_lt_of_le hm hmz2 + have hmz3 : m ≤ z3 := babylon_step_floor_bound x z2 m hz2P hmlo + have hz3P : 0 < z3 := Nat.lt_of_lt_of_le hm hmz3 + have hmz4 : m ≤ z4 := babylon_step_floor_bound x z3 m hz3P hmlo + have hz4P : 0 < z4 := Nat.lt_of_lt_of_le hm hmz4 + have hmz5 : m ≤ z5 := babylon_step_floor_bound x z4 m hz4P hmlo + have hd1 : z1 - m ≤ fd1_255 := by + simpa [z1, fd1_255, maxAbs255] using + d1_bound x m FIXED_SEED lo255 hi255 fixed_seed_pos hmlo hmhi hlo hhi + have hd1m : fd1_255 ≤ m := Nat.le_trans fd1_255_le_lo hlo + have hd2 : z2 - m ≤ fd2_255 := by + simpa [z2, fd2_255] using step_from_bound x m lo255 z1 fd1_255 hm lo255_pos hlo hmhi hmz1 hd1 hd1m + have hd2m : fd2_255 ≤ m := Nat.le_trans fd2_255_le_lo hlo + have hd3 : z3 - m ≤ fd3_255 := by + simpa [z3, fd3_255] using step_from_bound x m lo255 z2 fd2_255 hm lo255_pos hlo hmhi hmz2 hd2 hd2m + have hd3m : fd3_255 ≤ m := Nat.le_trans fd3_255_le_lo hlo + have hd4 : z4 - m ≤ fd4_255 := by + simpa [z4, fd4_255] using step_from_bound x m lo255 z3 fd3_255 hm lo255_pos hlo hmhi hmz3 hd3 hd3m + have hd4m : fd4_255 ≤ m := Nat.le_trans fd4_255_le_lo hlo + have hd5 : z5 - m ≤ fd5_255 := by + simpa [z5, fd5_255] using step_from_bound x m lo255 z4 fd4_255 hm lo255_pos hlo hmhi hmz4 hd4 hd4m + have hd5m : fd5_255 ≤ m := Nat.le_trans fd5_255_le_lo hlo + have hd6 : z6 - m ≤ fd6_255 := by + simpa [z6, fd6_255] using step_from_bound x m lo255 z5 fd5_255 hm lo255_pos hlo hmhi hmz5 hd5 hd5m + simpa [run6Fixed, z1, z2, z3, z4, z5, z6] using hd6 + +-- ============================================================================ +-- Section 4: Combined fixed-seed bracket + floor correction +-- ============================================================================ + +private theorem m_le_run6Fixed (x m : Nat) (hx : 0 < x) (hmlo : m * m ≤ x) : + m ≤ run6Fixed x := by + let z1 := bstep x FIXED_SEED; let z2 := bstep x z1; let z3 := bstep x z2 + let z4 := bstep x z3; let z5 := bstep x z4; let z6 := bstep x z5 + have hz1 : 0 < z1 := bstep_pos x FIXED_SEED hx fixed_seed_pos + have hz2 : 0 < z2 := bstep_pos x z1 hx hz1 + have hz3 : 0 < z3 := bstep_pos x z2 hx hz2 + have hz4 : 0 < z4 := bstep_pos x z3 hx hz3 + have hz5 : 0 < z5 := bstep_pos x z4 hx hz4 + simpa [run6Fixed, z1, z2, z3, z4, z5, z6] using + babylon_step_floor_bound x z5 m hz5 hmlo + +theorem fixed_seed_bracket (x : Nat) (hlo : 2 ^ 254 ≤ x) (hhi : x < 2 ^ 256) : + natSqrt x ≤ run6Fixed x ∧ run6Fixed x ≤ natSqrt x + 1 := by + have hmlo := natSqrt_sq_le x + have hmhi := natSqrt_lt_succ_sq x + have hm : 0 < natSqrt x := by + suffices natSqrt x ≠ 0 by omega + intro h0; have := natSqrt_lt_succ_sq x; rw [h0] at this; omega + constructor + · exact m_le_run6Fixed x (natSqrt x) (by omega) hmlo + · suffices run6Fixed x - natSqrt x ≤ 1 by omega + by_cases hlt : x < 2 ^ 255 + · have hOct : 2 ^ (254 : Fin 256).val ≤ x ∧ x < 2 ^ ((254 : Fin 256).val + 1) := ⟨hlo, hlt⟩ + have hint := m_within_cert_interval ⟨254, by omega⟩ x (natSqrt x) hmlo hmhi hOct + exact Nat.le_trans (run6Fixed_error_254 x (natSqrt x) hm hmlo hmhi hint.1 hint.2) fd6_254_le_one + · have h255 : 2 ^ 255 ≤ x := Nat.le_of_not_lt hlt + have hOct : 2 ^ (⟨255, by omega⟩ : Fin 256).val ≤ x ∧ + x < 2 ^ ((⟨255, by omega⟩ : Fin 256).val + 1) := ⟨h255, hhi⟩ + have hint := m_within_cert_interval ⟨255, by omega⟩ x (natSqrt x) hmlo hmhi hOct + exact Nat.le_trans (run6Fixed_error_255 x (natSqrt x) hm hmlo hmhi hint.1 hint.2) fd6_255_le_one + +theorem floorSqrt_fixed_eq_natSqrt (x : Nat) (hlo : 2 ^ 254 ≤ x) (hhi : x < 2 ^ 256) : + floorSqrt_fixed x = natSqrt x := by + have hbr := fixed_seed_bracket x hlo hhi + have hz_pos : 0 < run6Fixed x := by + have hm_pos : 0 < natSqrt x := by + suffices natSqrt x ≠ 0 by omega + intro h0; have := natSqrt_lt_succ_sq x; rw [h0] at this; omega + exact Nat.lt_of_lt_of_le hm_pos hbr.1 + have hcorr := correction_correct x (run6Fixed x) hbr.1 hbr.2 + have h1 : floorSqrt_fixed x = + (if x / run6Fixed x < run6Fixed x then run6Fixed x - 1 else run6Fixed x) := by + unfold floorSqrt_fixed; dsimp only []; exact if_neg (Nat.ne_of_gt hz_pos) + rw [h1] + simp only [show (x / run6Fixed x < run6Fixed x) = (x < run6Fixed x * run6Fixed x) from + propext (Nat.div_lt_iff_lt_mul hz_pos)] + exact hcorr + +-- ============================================================================ +-- Section 5: EVM operation simplification helpers +-- ============================================================================ + +open Sqrt512GeneratedModel in +/-- normAdd (unbounded) is just addition. -/ +private theorem normAdd_eq (a b : Nat) : normAdd a b = a + b := rfl + +open Sqrt512GeneratedModel in +/-- normShr is division by power of 2. -/ +private theorem normShr_eq (s v : Nat) : normShr s v = v / 2 ^ s := rfl + +open Sqrt512GeneratedModel in +/-- normDiv is Nat division. -/ +private theorem normDiv_eq (a b : Nat) : normDiv a b = a / b := rfl + +open Sqrt512GeneratedModel in +/-- normMod is Nat modulo. -/ +private theorem normMod_eq (a b : Nat) : normMod a b = a % b := rfl + +open Sqrt512GeneratedModel in +/-- normSub is Nat subtraction. -/ +private theorem normSub_eq (a b : Nat) : normSub a b = a - b := rfl + +open Sqrt512GeneratedModel in +/-- normMul is Nat multiplication. -/ +private theorem normMul_eq (a b : Nat) : normMul a b = a * b := rfl + +open Sqrt512GeneratedModel in +/-- normNot 0 = 2^256 - 1. -/ +private theorem normNot_zero : normNot 0 = WORD_MOD - 1 := rfl + +open Sqrt512GeneratedModel in +/-- normClz for positive x < 2^256 gives 255 - log2 x. -/ +private theorem normClz_pos (x : Nat) (hx : 0 < x) : + normClz x = 255 - Nat.log2 x := by + unfold normClz; simp [Nat.ne_of_gt hx] + +open Sqrt512GeneratedModel in +/-- normLt is a 0/1 indicator. -/ +private theorem normLt_eq (a b : Nat) : normLt a b = if a < b then 1 else 0 := rfl + +open Sqrt512GeneratedModel in +open Sqrt512GeneratedModel in +/-- Norm bitwise: and(and(1, 255), 255) = 1. -/ +private theorem normAnd_1_255_255 : normAnd (normAnd 1 255) 255 = 1 := by decide + +open Sqrt512GeneratedModel in +/-- One Babylonian step in the norm model equals bstep. -/ +private theorem normStep_eq_bstep (x z : Nat) : + normShr (normAnd (normAnd 1 255) 255) (normAdd (normDiv x z) z) = bstep x z := by + simp [normAnd_1_255_255, normShr_eq, normAdd_eq, normDiv_eq, bstep, Nat.add_comm] + +open Sqrt512GeneratedModel in +/-- The generated model_bstep equals bstep (definitional). -/ +theorem model_bstep_eq_bstep (x z : Nat) : model_bstep x z = bstep x z := + normStep_eq_bstep x z + +open Sqrt512GeneratedModel in +/-- Floor correction: sub z (lt (div x z) z) gives the standard correction. -/ +private theorem normFloor_correction (x z : Nat) (hz : 0 < z) : + normSub z (normLt (normDiv x z) z) = + (if x / z < z then z - 1 else z) := by + simp only [normSub_eq, normLt_eq, normDiv_eq] + split <;> omega + +-- ============================================================================ +-- Section 5b: Constant-folding and bitwise helpers +-- ============================================================================ + +set_option maxRecDepth 4096 in +/-- For n < 256, n &&& 254 clears bit 0, giving 2*(n/2). -/ +private theorem and_254_eq : ∀ n : Fin 256, (n.val &&& 254) = 2 * (n.val / 2) := by + decide + +private theorem normAnd_shift_254 (n : Nat) (hn : n < 256) : + n &&& 254 = 2 * (n / 2) := + and_254_eq ⟨n, hn⟩ + +private theorem and_1_255 : (1 : Nat) &&& (255 : Nat) = 1 := by decide +private theorem and_128_255 : (128 : Nat) &&& (255 : Nat) = 128 := by decide + +/-- Bitwise OR equals addition when bits don't overlap. + Uses Nat.shiftLeft_add_eq_or_of_lt from Init. -/ +private theorem or_eq_add_shl (a b s : Nat) (hb : b < 2 ^ s) : + (a * 2 ^ s) ||| b = a * 2 ^ s + b := by + rw [← Nat.shiftLeft_eq] + exact (Nat.shiftLeft_add_eq_or_of_lt hb a).symm + +-- ============================================================================ +-- Section 6: Inner sqrt convergence (reusable for EVM bridge) +-- ============================================================================ + +-- The norm model's Babylonian steps (using unbounded normAdd) are identical to +-- bstep, and therefore converge to natSqrt on normalized inputs [2^254, 2^256). +-- For the EVM bridge, we reuse this by showing the EVM Babylonian steps produce +-- the same values (since the sums don't overflow 2^256). + +open Sqrt512GeneratedModel in +/-- The 6 model_bstep calls equal run6Fixed. -/ +private theorem norm_6steps_eq_run6Fixed (x_hi_1 : Nat) : + let r_hi_1 := FIXED_SEED + let r_hi_2 := model_bstep x_hi_1 r_hi_1 + let r_hi_3 := model_bstep x_hi_1 r_hi_2 + let r_hi_4 := model_bstep x_hi_1 r_hi_3 + let r_hi_5 := model_bstep x_hi_1 r_hi_4 + let r_hi_6 := model_bstep x_hi_1 r_hi_5 + let r_hi_7 := model_bstep x_hi_1 r_hi_6 + r_hi_7 = run6Fixed x_hi_1 := by + simp only [model_bstep_eq_bstep, run6Fixed, FIXED_SEED, bstep] + +open Sqrt512GeneratedModel in +/-- The 6 steps + floor correction in the norm model = floorSqrt_fixed. -/ +private theorem norm_inner_sqrt_eq_floorSqrt_fixed (x_hi_1 : Nat) (hx : 0 < x_hi_1) : + let r_hi_1 := FIXED_SEED + let r_hi_2 := model_bstep x_hi_1 r_hi_1 + let r_hi_3 := model_bstep x_hi_1 r_hi_2 + let r_hi_4 := model_bstep x_hi_1 r_hi_3 + let r_hi_5 := model_bstep x_hi_1 r_hi_4 + let r_hi_6 := model_bstep x_hi_1 r_hi_5 + let r_hi_7 := model_bstep x_hi_1 r_hi_6 + let r_hi_8 := normSub r_hi_7 (normLt (normDiv x_hi_1 r_hi_7) r_hi_7) + r_hi_8 = floorSqrt_fixed x_hi_1 := by + simp only [model_bstep_eq_bstep] + have h7 := norm_6steps_eq_run6Fixed x_hi_1 + simp only [model_bstep_eq_bstep] at h7 + have hz_pos : 0 < run6Fixed x_hi_1 := by + have hseed_pos : 0 < FIXED_SEED := fixed_seed_pos + have hz1_pos := bstep_pos x_hi_1 FIXED_SEED hx hseed_pos + have hz2_pos := bstep_pos x_hi_1 _ hx hz1_pos + have hz3_pos := bstep_pos x_hi_1 _ hx hz2_pos + have hz4_pos := bstep_pos x_hi_1 _ hx hz3_pos + have hz5_pos := bstep_pos x_hi_1 _ hx hz4_pos + have hz6_pos := bstep_pos x_hi_1 _ hx hz5_pos + exact hz6_pos + rw [h7, normFloor_correction x_hi_1 (run6Fixed x_hi_1) hz_pos] + unfold floorSqrt_fixed + simp [Nat.ne_of_gt hz_pos] + +open Sqrt512GeneratedModel in +/-- The norm inner sqrt gives natSqrt on normalized inputs. -/ +private theorem norm_inner_sqrt_eq_natSqrt (x_hi_1 : Nat) + (hlo : 2 ^ 254 ≤ x_hi_1) (hhi : x_hi_1 < 2 ^ 256) : + let r_hi_1 := FIXED_SEED + let r_hi_2 := model_bstep x_hi_1 r_hi_1 + let r_hi_3 := model_bstep x_hi_1 r_hi_2 + let r_hi_4 := model_bstep x_hi_1 r_hi_3 + let r_hi_5 := model_bstep x_hi_1 r_hi_4 + let r_hi_6 := model_bstep x_hi_1 r_hi_5 + let r_hi_7 := model_bstep x_hi_1 r_hi_6 + let r_hi_8 := normSub r_hi_7 (normLt (normDiv x_hi_1 r_hi_7) r_hi_7) + r_hi_8 = natSqrt x_hi_1 := by + have hpos : 0 < x_hi_1 := by omega + have h := norm_inner_sqrt_eq_floorSqrt_fixed x_hi_1 hpos + simp only [model_bstep_eq_bstep] at h ⊢ + rw [h] + exact floorSqrt_fixed_eq_natSqrt x_hi_1 hlo hhi + +-- ============================================================================ +-- Section 7: EVM operation bridge lemmas +-- ============================================================================ + +section EvmNormBridge +open Sqrt512GeneratedModel + +theorem u256_id' (x : Nat) (hx : x < WORD_MOD) : u256 x = x := + Nat.mod_eq_of_lt hx + +theorem evmSub_eq_of_le (a b : Nat) (ha : a < WORD_MOD) (hb : b ≤ a) : + evmSub a b = a - b := by + have hb' : b < WORD_MOD := Nat.lt_of_le_of_lt hb ha + have hab' : a - b < WORD_MOD := Nat.lt_of_le_of_lt (Nat.sub_le a b) ha + unfold evmSub u256 + simp [Nat.mod_eq_of_lt ha, Nat.mod_eq_of_lt hb'] + have hsplit : a + WORD_MOD - b = WORD_MOD + (a - b) := by omega + rw [hsplit, Nat.add_mod, Nat.mod_eq_zero_of_dvd (Nat.dvd_refl WORD_MOD), Nat.zero_add, + Nat.mod_mod_of_dvd, Nat.mod_eq_of_lt hab'] + exact Nat.dvd_refl WORD_MOD + +theorem evmDiv_eq' (a b : Nat) (ha : a < WORD_MOD) (hb : 0 < b) (hb' : b < WORD_MOD) : + evmDiv a b = a / b := by + unfold evmDiv + simp only [u256_id' a ha, u256_id' b hb'] + simp [Nat.ne_of_gt hb] + +theorem evmMod_eq' (a b : Nat) (ha : a < WORD_MOD) (hb : 0 < b) (hb' : b < WORD_MOD) : + evmMod a b = a % b := by + unfold evmMod + simp only [u256_id' a ha, u256_id' b hb'] + simp [Nat.ne_of_gt hb] + +theorem evmOr_eq' (a b : Nat) (ha : a < WORD_MOD) (hb : b < WORD_MOD) : + evmOr a b = a ||| b := by + unfold evmOr; simp [u256_id' a ha, u256_id' b hb] + +theorem evmAnd_eq' (a b : Nat) (ha : a < WORD_MOD) (hb : b < WORD_MOD) : + evmAnd a b = a &&& b := by + unfold evmAnd; simp [u256_id' a ha, u256_id' b hb] + +theorem evmShr_eq' (s v : Nat) (hs : s < 256) (hv : v < WORD_MOD) : + evmShr s v = v / 2 ^ s := by + have hs' : s < WORD_MOD := by unfold WORD_MOD; omega + unfold evmShr; simp [u256_id' s hs', u256_id' v hv, hs] + +theorem evmShl_eq' (s v : Nat) (hs : s < 256) (hv : v < WORD_MOD) : + evmShl s v = (v * 2 ^ s) % WORD_MOD := by + have hs' : s < WORD_MOD := by unfold WORD_MOD; omega + unfold evmShl u256 + simp [Nat.mod_eq_of_lt hs', Nat.mod_eq_of_lt hv, hs] + +theorem evmAdd_eq' (a b : Nat) (ha : a < WORD_MOD) (hb : b < WORD_MOD) + (hab : a + b < WORD_MOD) : + evmAdd a b = a + b := by + unfold evmAdd u256 + simp [Nat.mod_eq_of_lt ha, Nat.mod_eq_of_lt hb, Nat.mod_eq_of_lt hab] + +theorem evmMul_eq' (a b : Nat) (ha : a < WORD_MOD) (hb : b < WORD_MOD) : + evmMul a b = (a * b) % WORD_MOD := by + unfold evmMul u256 + simp [Nat.mod_eq_of_lt ha, Nat.mod_eq_of_lt hb] + +theorem evmClz_eq' (v : Nat) (hv : v < WORD_MOD) : + evmClz v = if v = 0 then 256 else 255 - Nat.log2 v := by + unfold evmClz; simp [u256_id' v hv] + +theorem evmLt_eq' (a b : Nat) (ha : a < WORD_MOD) (hb : b < WORD_MOD) : + evmLt a b = if a < b then 1 else 0 := by + unfold evmLt; simp [u256_id' a ha, u256_id' b hb] + +theorem evmEq_eq' (a b : Nat) (ha : a < WORD_MOD) (hb : b < WORD_MOD) : + evmEq a b = if a = b then 1 else 0 := by + unfold evmEq; simp [u256_id' a ha, u256_id' b hb] + +theorem evmNot_eq' (a : Nat) (ha : a < WORD_MOD) : + evmNot a = WORD_MOD - 1 - a := by + unfold evmNot; simp [u256_id' a ha] + +/-- When a + b = WORD_MOD and f ∈ {0,1}, EVM overflow+underflow gives the right answer. -/ +theorem evmSub_evmAdd_eq_of_overflow (a b : Nat) + (ha : a < WORD_MOD) (hb : b < WORD_MOD) + (hab : a + b = WORD_MOD) : + evmSub (evmAdd a b) 1 = WORD_MOD - 1 := by + unfold evmAdd evmSub u256 + simp [Nat.mod_eq_of_lt ha, Nat.mod_eq_of_lt hb, hab, Nat.mod_self] + have h1 : (1 : Nat) < WORD_MOD := by unfold WORD_MOD; omega + simp [Nat.mod_eq_of_lt h1] + +/-- Generic: (a * n) % (n * n) = (a % n) * n -/ +private theorem mul_mod_sq (a n : Nat) (hn : 0 < n) : + (a * n) % (n * n) = (a % n) * n := by + -- a = n * (a/n) + a%n, so a*n = n*n*(a/n) + (a%n)*n + have h := Nat.div_add_mod a n -- n * (a / n) + a % n = a + have ha : a * n = n * n * (a / n) + a % n * n := by + have h2 : a * n = (n * (a / n) + a % n) * n := by rw [h] + rw [h2, Nat.add_mul] + show n * (a / n) * n + a % n * n = n * n * (a / n) + a % n * n + congr 1 + rw [Nat.mul_assoc, Nat.mul_comm (a / n) n, ← Nat.mul_assoc] + rw [ha, Nat.mul_add_mod] + exact Nat.mod_eq_of_lt (Nat.mul_lt_mul_of_pos_right (Nat.mod_lt a hn) hn) + +/-- Key: (a * 2^128) % 2^256 = (a % 2^128) * 2^128 -/ +private theorem mul_pow128_mod_word (a : Nat) : + (a * 2 ^ 128) % WORD_MOD = (a % 2 ^ 128) * 2 ^ 128 := by + have : WORD_MOD = 2 ^ 128 * 2 ^ 128 := by unfold WORD_MOD; rw [← Nat.pow_add] + rw [this]; exact mul_mod_sq a (2 ^ 128) (Nat.two_pow_pos 128) + +/-- Euclidean division after recomposition: (d*q + r)/d = q + r/d -/ +private theorem div_of_mul_add (d q r : Nat) (hd : 0 < d) : + (d * q + r) / d = q + r / d := by + rw [show d * q + r = r + q * d from by rw [Nat.mul_comm, Nat.add_comm], + Nat.add_mul_div_right r q hd, Nat.add_comm] + +/-- Euclidean mod after recomposition: (d*q + r) % d = r % d -/ +private theorem mod_of_mul_add (d q r : Nat) : + (d * q + r) % d = r % d := by + rw [show d * q + r = r + q * d from by rw [Nat.mul_comm, Nat.add_comm]] + exact Nat.add_mul_mod_self_right r q d + +end EvmNormBridge + +-- ============================================================================ +-- Section 8: Sub-model bridge theorems +-- ============================================================================ + +-- With the refactored Solidity code, model_sqrt512_evm now calls three +-- sub-models: model_innerSqrt_evm, model_karatsubaQuotient_evm, and +-- model_sqrtCorrection_evm. Each sub-model is proved correct independently, +-- then chained in the top-level theorem. + +section SubModelBridge +open Sqrt512GeneratedModel + +/-- The norm model of _innerSqrt gives (floorSqrt_fixed x, x - floorSqrt_fixed(x)²). + Follows from norm_inner_sqrt_eq_floorSqrt_fixed by unfolding model_innerSqrt. -/ +private theorem model_innerSqrt_fst_eq_floorSqrt_fixed (x_hi_1 : Nat) (hx : 0 < x_hi_1) : + (model_innerSqrt x_hi_1).1 = floorSqrt_fixed x_hi_1 := by + unfold model_innerSqrt + exact norm_inner_sqrt_eq_floorSqrt_fixed x_hi_1 hx + +/-- The norm model of _innerSqrt gives natSqrt on normalized inputs. -/ +private theorem model_innerSqrt_fst_eq_natSqrt (x_hi_1 : Nat) + (hlo : 2 ^ 254 ≤ x_hi_1) (hhi : x_hi_1 < 2 ^ 256) : + (model_innerSqrt x_hi_1).1 = natSqrt x_hi_1 := by + have hpos : 0 < x_hi_1 := by omega + rw [model_innerSqrt_fst_eq_floorSqrt_fixed x_hi_1 hpos] + exact floorSqrt_fixed_eq_natSqrt x_hi_1 hlo hhi + +end SubModelBridge + +-- ============================================================================ +-- Section 9: Direct EVM model → sqrt512 bridge +-- ============================================================================ + +-- We prove model_sqrt512_evm = sqrt512 DIRECTLY, without going through the +-- norm model (model_sqrt512). The norm model uses unbounded normShl/normMul +-- which don't match EVM uint256 semantics, making it unsuitable as an intermediate. +-- +-- The EVM model uses u256-wrapped operations that correctly implement the +-- Solidity algorithm. We show its output equals sqrt512(x_hi * 2^256 + x_lo). +-- +-- With the refactored model structure, the proof decomposes into sub-lemmas +-- that each unfold only ONE sub-model: +-- A. EVM normalization: x_hi_1 = x*4^k/2^256, x_lo_1 = x*4^k%2^256 +-- B. model_innerSqrt_evm → (natSqrt(x_hi_1), x_hi_1 - natSqrt(x_hi_1)²) +-- C. model_karatsubaQuotient_evm → quotient/remainder with carry correction +-- D. model_sqrtCorrection_evm → combine + 257-bit correction = karatsubaFloor +-- E. Chain: karatsubaFloor(x_hi_1, x_lo_1) / 2^k = sqrt512(x) + +section EvmBridge +open Sqrt512GeneratedModel + +/-- Sub-lemma A: The EVM normalization phase computes the correct normalized words. + Given x = x_hi * 2^256 + x_lo and k = (255 - log2 x_hi) / 2: + - x_hi_1 = (x * 4^k) / 2^256 + - x_lo_1 = (x * 4^k) % 2^256 + - shift_1 = k -/ +-- 512-bit left shift decomposition into high/low 256-bit words. +theorem shl512_hi (x_hi x_lo s : Nat) (hs : s ≤ 255) : + (x_hi * 2 ^ 256 + x_lo) * 2 ^ s / 2 ^ 256 = + x_hi * 2 ^ s + x_lo / 2 ^ (256 - s) := by + have hrw : (x_hi * 2 ^ 256 + x_lo) * 2 ^ s = + x_lo * 2 ^ s + x_hi * 2 ^ s * 2 ^ 256 := by + rw [Nat.add_mul, Nat.mul_right_comm]; omega + rw [hrw, Nat.add_mul_div_right _ _ (Nat.two_pow_pos 256), Nat.add_comm] + congr 1 + have h256_split : 2 ^ 256 = 2 ^ (256 - s) * 2 ^ s := by + rw [← Nat.pow_add]; congr 1; omega + rw [h256_split] + exact Nat.mul_div_mul_right _ _ (Nat.two_pow_pos s) + +theorem shl512_lo' (x_hi x_lo s : Nat) : + (x_hi * 2 ^ 256 + x_lo) * 2 ^ s % 2 ^ 256 = + (x_lo * 2 ^ s) % 2 ^ 256 := by + have hrw : (x_hi * 2 ^ 256 + x_lo) * 2 ^ s = + x_lo * 2 ^ s + x_hi * 2 ^ s * 2 ^ 256 := by + rw [Nat.add_mul, Nat.mul_right_comm]; omega + rw [hrw, Nat.add_mul_mod_self_right] + +-- x_hi * 2^s < 2^256 when x_hi * 2^s is exactly the product (no overflow) +-- and shift_range guarantees this. +private theorem shl_no_overflow (x_hi s : Nat) (h : x_hi * 2 ^ s < 2 ^ 256) : + (x_hi * 2 ^ s) % (2 ^ 256) = x_hi * 2 ^ s := + Nat.mod_eq_of_lt h + +-- The bottom s bits of (x_hi * 2^s) % 2^256 are zero, so OR = add with values < 2^s. +private theorem shl_or_shr (x_hi x_lo s : Nat) (hs : 0 < s) (hs' : s ≤ 255) + (hxlo : x_lo < 2 ^ 256) : + (x_hi * 2 ^ s) ||| (x_lo / 2 ^ (256 - s)) = + x_hi * 2 ^ s + x_lo / 2 ^ (256 - s) := by + have hcarry : x_lo / 2 ^ (256 - s) < 2 ^ s := by + rw [Nat.div_lt_iff_lt_mul (Nat.two_pow_pos _)] + calc x_lo < 2 ^ 256 := hxlo + _ = 2 ^ s * 2 ^ (256 - s) := by rw [← Nat.pow_add]; congr 1; omega + -- x_hi * 2^s is a multiple of 2^s, carry < 2^s, so bits don't overlap + exact or_eq_add_shl x_hi (x_lo / 2 ^ (256 - s)) s hcarry + +-- Full high word computation: OR of SHL and SHR equals the high word of the 512-bit shift. +private theorem shl512_hi_or (x_hi x_lo s : Nat) (hs : 0 < s) (hs' : s ≤ 255) + (hxhi_shl : x_hi * 2 ^ s < 2 ^ 256) (hxlo : x_lo < 2 ^ 256) : + ((x_hi * 2 ^ s) % 2 ^ 256) ||| (x_lo / 2 ^ (256 - s)) = + (x_hi * 2 ^ 256 + x_lo) * 2 ^ s / 2 ^ 256 := by + rw [shl_no_overflow x_hi s hxhi_shl, shl_or_shr x_hi x_lo s hs hs' hxlo, + shl512_hi x_hi x_lo s hs'] + +private theorem evm_normalization_correct (x_hi x_lo : Nat) + (hxhi_pos : 0 < x_hi) (hxhi_lt : x_hi < 2 ^ 256) (hxlo_lt : x_lo < 2 ^ 256) : + let x := x_hi * 2 ^ 256 + x_lo + let k := (255 - Nat.log2 x_hi) / 2 + let shift := evmClz (u256 x_hi) + let dbl_k := evmAnd shift 254 + let x_lo_1 := evmShl dbl_k (u256 x_lo) + let x_hi_1 := evmOr (evmShl dbl_k (u256 x_hi)) (evmShr (evmSub 256 dbl_k) (u256 x_lo)) + let shift_1 := evmShr (evmAnd (evmAnd 1 255) 255) shift + x_hi_1 = x * 4 ^ k / 2 ^ 256 ∧ + x_lo_1 = x * 4 ^ k % 2 ^ 256 ∧ + shift_1 = k ∧ + 2 ^ 254 ≤ x_hi_1 ∧ + x_hi_1 < 2 ^ 256 ∧ + x_lo_1 < 2 ^ 256 := by + -- Inline all let-bindings upfront so rw/simp work on concrete expressions. + -- dbl_k = evmAnd (evmClz (u256 x_hi)) 254 + -- shift = evmClz (u256 x_hi) + -- x = x_hi * 2^256 + x_lo, k = (255 - log2 x_hi) / 2 + show let x := x_hi * 2 ^ 256 + x_lo; let k := (255 - Nat.log2 x_hi) / 2; + let dbl_k := evmAnd (evmClz (u256 x_hi)) 254; + evmOr (evmShl dbl_k (u256 x_hi)) (evmShr (evmSub 256 dbl_k) (u256 x_lo)) = + x * 4 ^ k / 2 ^ 256 ∧ + evmShl dbl_k (u256 x_lo) = x * 4 ^ k % 2 ^ 256 ∧ + evmShr (evmAnd (evmAnd 1 255) 255) (evmClz (u256 x_hi)) = k ∧ + 2 ^ 254 ≤ evmOr (evmShl dbl_k (u256 x_hi)) (evmShr (evmSub 256 dbl_k) (u256 x_lo)) ∧ + evmOr (evmShl dbl_k (u256 x_hi)) (evmShr (evmSub 256 dbl_k) (u256 x_lo)) < 2 ^ 256 ∧ + evmShl dbl_k (u256 x_lo) < 2 ^ 256 + intro x; intro k; intro dbl_k + have hxhi_wm : x_hi < WORD_MOD := hxhi_lt + have hxlo_wm : x_lo < WORD_MOD := hxlo_lt + have hxhi_ne : x_hi ≠ 0 := Nat.ne_of_gt hxhi_pos + have hlog_le : Nat.log2 x_hi ≤ 255 := by + have := (Nat.log2_lt hxhi_ne).2 hxhi_lt; omega + -- Step 1: evmClz (u256 x_hi) = 255 - log2(x_hi) + have hshift_eq : evmClz (u256 x_hi) = 255 - Nat.log2 x_hi := by + rw [u256_id' x_hi hxhi_wm, evmClz_eq' x_hi hxhi_wm]; simp [hxhi_ne] + have hshift_wm : evmClz (u256 x_hi) < WORD_MOD := by rw [hshift_eq]; unfold WORD_MOD; omega + -- Step 2: dbl_k = 2 * k + have hdbl_k : dbl_k = 2 * k := by + show evmAnd (evmClz (u256 x_hi)) 254 = _ + rw [evmAnd_eq' _ 254 hshift_wm (by unfold WORD_MOD; omega), hshift_eq] + exact normAnd_shift_254 (255 - Nat.log2 x_hi) (by omega) + have hdbl_k_lt : dbl_k < 256 := by omega + have hdbl_k_le : dbl_k ≤ 254 := by omega + -- Step 3: shift_1 = k + have hshift_1_eq : evmShr (evmAnd (evmAnd 1 255) 255) (evmClz (u256 x_hi)) = k := by + have h1 : (1 : Nat) < WORD_MOD := by unfold WORD_MOD; omega + have h255 : (255 : Nat) < WORD_MOD := by unfold WORD_MOD; omega + rw [evmAnd_eq' 1 255 h1 h255, and_1_255, evmAnd_eq' 1 255 h1 h255, and_1_255] + rw [evmShr_eq' 1 _ (by omega) hshift_wm, hshift_eq, Nat.pow_one] + -- Step 4: 4^k = 2^dbl_k + have hfour_eq : 4 ^ k = 2 ^ dbl_k := by + rw [hdbl_k, show (4 : Nat) = 2 ^ 2 from by decide, ← Nat.pow_mul] + -- Step 5: x_hi * 2^dbl_k < 2^256 + have hsr := shift_range x_hi hxhi_pos hxhi_lt + have hxhi_shl_lt : x_hi * 2 ^ dbl_k < 2 ^ 256 := by rw [← hfour_eq]; exact hsr.2 + -- Step 6: Simplify EVM operations + have hsub_eq : evmSub 256 dbl_k = 256 - dbl_k := + evmSub_eq_of_le 256 dbl_k (by unfold WORD_MOD; omega) (by omega) + have hshl_xhi : evmShl dbl_k (u256 x_hi) = (x_hi * 2 ^ dbl_k) % WORD_MOD := by + rw [u256_id' x_hi hxhi_wm]; exact evmShl_eq' dbl_k x_hi hdbl_k_lt hxhi_wm + -- Steps 6-10 use complex EVM-to-Nat rewrites. We case-split on dbl_k = 0 + -- (which is the only case where 256 - dbl_k = 256 makes evmShr behave differently). + by_cases hdbl_k_zero : dbl_k = 0 + · -- CASE: dbl_k = 0, so k = 0, x already normalized + have hk_zero : k = 0 := by omega + -- With dbl_k = 0 and k = 0: 4^k = 1, x * 1 = x, x/2^256 = x_hi, x%2^256 = x_lo + -- Since dbl_k is a let-binding, we can't rw it directly. Use simp + show. + have hk_zero : k = 0 := by omega + -- Simplify all EVM ops with dbl_k = 0 + have hu256_xhi : u256 x_hi = x_hi := u256_id' x_hi hxhi_wm + have hu256_xlo : u256 x_lo = x_lo := u256_id' x_lo hxlo_wm + have hxhi1_eq : evmOr (evmShl dbl_k (u256 x_hi)) (evmShr (evmSub 256 dbl_k) (u256 x_lo)) = x_hi := by + rw [hdbl_k_zero, hu256_xhi, hu256_xlo] + rw [evmShl_eq' 0 x_hi (by omega) hxhi_wm, Nat.pow_zero, Nat.mul_one] + unfold WORD_MOD + rw [Nat.mod_eq_of_lt hxhi_lt] + -- evmShr (256 - 0) x_lo = evmShr 256 x_lo = 0 (since 256 is not < 256) + rw [evmSub_eq_of_le 256 0 (by unfold WORD_MOD; omega) (by omega)] + -- Goal: evmOr x_hi (evmShr 256 x_lo) = x_hi + -- evmShr 256 x_lo: shift 256 ≥ 256 so result is 0 + have : evmShr 256 x_lo = 0 := by + unfold evmShr u256 WORD_MOD; simp + rw [this] + rw [evmOr_eq' x_hi 0 hxhi_wm (by unfold WORD_MOD; omega)] + simp + have hxlo1_eq : evmShl dbl_k (u256 x_lo) = x_lo := by + rw [hdbl_k_zero, hu256_xlo, + evmShl_eq' 0 x_lo (by omega) hxlo_wm, Nat.pow_zero, Nat.mul_one] + unfold WORD_MOD; exact Nat.mod_eq_of_lt hxlo_lt + have hxdiv : x / 2 ^ 256 = x_hi := by + show (x_hi * 2 ^ 256 + x_lo) / 2 ^ 256 = x_hi + rw [Nat.mul_comm, Nat.mul_add_div (Nat.two_pow_pos 256), Nat.div_eq_of_lt hxlo_lt, + Nat.add_zero] + have hxmod : x % 2 ^ 256 = x_lo := by + show (x_hi * 2 ^ 256 + x_lo) % 2 ^ 256 = x_lo + rw [Nat.mul_comm, Nat.mul_add_mod]; exact Nat.mod_eq_of_lt hxlo_lt + -- 4^k = 4^0 = 1 when k = 0 + have h4k_one : 4 ^ k = 1 := by simp [hk_zero] + refine ⟨?_, ?_, hshift_1_eq, ?_, ?_, ?_⟩ + · rw [hxhi1_eq, h4k_one, Nat.mul_one, hxdiv] + · rw [hxlo1_eq, h4k_one, Nat.mul_one, hxmod] + · rw [hxhi1_eq]; have := hsr.1; rw [h4k_one, Nat.mul_one] at this; exact this + · rw [hxhi1_eq]; exact hxhi_lt + · rw [hxlo1_eq]; exact hxlo_lt + · -- CASE: dbl_k > 0, so 256 - dbl_k < 256 and evmShr works normally + have hdbl_k_pos : 0 < dbl_k := by omega + have hshr_xlo : evmShr (evmSub 256 dbl_k) (u256 x_lo) = x_lo / 2 ^ (256 - dbl_k) := by + rw [u256_id' x_lo hxlo_wm, hsub_eq] + exact evmShr_eq' (256 - dbl_k) x_lo (by omega) hxlo_wm + have hshl_xlo : evmShl dbl_k (u256 x_lo) = (x_lo * 2 ^ dbl_k) % WORD_MOD := by + rw [u256_id' x_lo hxlo_wm]; exact evmShl_eq' dbl_k x_lo hdbl_k_lt hxlo_wm + have hshl_xhi_wm : evmShl dbl_k (u256 x_hi) < WORD_MOD := by + rw [hshl_xhi]; exact Nat.mod_lt _ (by unfold WORD_MOD; omega) + have hshr_xlo_wm : evmShr (evmSub 256 dbl_k) (u256 x_lo) < WORD_MOD := by + rw [hshr_xlo]; unfold WORD_MOD + have : x_lo / 2 ^ (256 - dbl_k) < 2 ^ dbl_k := by + rw [Nat.div_lt_iff_lt_mul (Nat.two_pow_pos _)] + calc x_lo < 2 ^ 256 := hxlo_lt + _ = 2 ^ dbl_k * 2 ^ (256 - dbl_k) := by rw [← Nat.pow_add]; congr 1; omega + exact Nat.lt_of_lt_of_le this (Nat.pow_le_pow_right (by omega) (by omega)) + have hxhi1_eq : evmOr (evmShl dbl_k (u256 x_hi)) (evmShr (evmSub 256 dbl_k) (u256 x_lo)) = + x * 4 ^ k / 2 ^ 256 := by + rw [evmOr_eq' _ _ hshl_xhi_wm hshr_xlo_wm, hshl_xhi, hshr_xlo] + unfold WORD_MOD + rw [shl512_hi_or x_hi x_lo dbl_k hdbl_k_pos (by omega) hxhi_shl_lt hxlo_lt] + congr 1; rw [← hfour_eq] + have hxlo1_eq : evmShl dbl_k (u256 x_lo) = x * 4 ^ k % 2 ^ 256 := by + rw [hshl_xlo]; unfold WORD_MOD + rw [show x * 4 ^ k = (x_hi * 2 ^ 256 + x_lo) * 2 ^ dbl_k from by rw [← hfour_eq]] + exact (shl512_lo' x_hi x_lo dbl_k).symm + have hhi_eq : x * 4 ^ k / 2 ^ 256 = x_hi * 2 ^ dbl_k + x_lo / 2 ^ (256 - dbl_k) := by + rw [show x * 4 ^ k = (x_hi * 2 ^ 256 + x_lo) * 2 ^ dbl_k from by rw [← hfour_eq]] + exact shl512_hi x_hi x_lo dbl_k (by omega) + have hhi_lo_bound : 2 ^ 254 ≤ x * 4 ^ k / 2 ^ 256 := by + rw [hhi_eq]; have := hsr.1; rw [hfour_eq] at this; omega + have hshr_xlo_val : x_lo / 2 ^ (256 - dbl_k) < 2 ^ dbl_k := by + rw [Nat.div_lt_iff_lt_mul (Nat.two_pow_pos _)] + calc x_lo < 2 ^ 256 := hxlo_lt + _ = 2 ^ dbl_k * 2 ^ (256 - dbl_k) := by rw [← Nat.pow_add]; congr 1; omega + have hhi_hi_bound : x * 4 ^ k / 2 ^ 256 < 2 ^ 256 := by + rw [hhi_eq] + have h2 : (x_hi + 1) * 2 ^ dbl_k ≤ 2 ^ 256 := by + rw [Nat.succ_mul] + have h256 : 2 ^ 256 = 2 ^ dbl_k * 2 ^ (256 - dbl_k) := by + rw [← Nat.pow_add]; congr 1; omega + -- hxhi_shl_lt : x_hi * 2^dbl_k < 2^256 + -- Goal: x_hi * 2^dbl_k + 2^dbl_k ≤ 2^256 + -- From x_hi * 2^dbl_k < 2^256 = 2^dbl_k * 2^(256-dbl_k) + -- we get x_hi < 2^(256-dbl_k), so (x_hi+1) * 2^dbl_k ≤ 2^(256-dbl_k) * 2^dbl_k = 2^256 + rw [h256] at hxhi_shl_lt ⊢ + have hxhi_lt_pow : x_hi < 2 ^ (256 - dbl_k) := by + rw [Nat.mul_comm] at hxhi_shl_lt + exact Nat.lt_of_mul_lt_mul_left hxhi_shl_lt + calc x_hi * 2 ^ dbl_k + 2 ^ dbl_k + = (x_hi + 1) * 2 ^ dbl_k := by rw [Nat.succ_mul] + _ ≤ 2 ^ (256 - dbl_k) * 2 ^ dbl_k := + Nat.mul_le_mul_right _ hxhi_lt_pow + _ = 2 ^ dbl_k * 2 ^ (256 - dbl_k) := Nat.mul_comm _ _ + calc x_hi * 2 ^ dbl_k + x_lo / 2 ^ (256 - dbl_k) + < x_hi * 2 ^ dbl_k + 2 ^ dbl_k := by omega + _ = (x_hi + 1) * 2 ^ dbl_k := by rw [Nat.succ_mul] + _ ≤ 2 ^ 256 := h2 + have hlo1_bound : evmShl dbl_k (u256 x_lo) < 2 ^ 256 := by + rw [hxlo1_eq]; exact Nat.mod_lt _ (by omega) + exact ⟨hxhi1_eq, hxlo1_eq, hshift_1_eq, + hxhi1_eq ▸ hhi_lo_bound, hxhi1_eq ▸ hhi_hi_bound, hlo1_bound⟩ + +/-- One EVM Babylonian step equals bstep when z ≥ 2^127, z < 2^129, x ∈ [2^254, 2^256). + The sum z + x/z < 2^129 + 2^129 = 2^130 < 2^256 so evmAdd doesn't overflow. + Also preserves the bound: 2^127 ≤ bstep x z < 2^129. -/ +private theorem evm_bstep_eq (x z : Nat) + (hx_lo : 2 ^ 254 ≤ x) (hx_hi : x < WORD_MOD) + (hz_lo : 2 ^ 127 ≤ z) (hz_hi : z < 2 ^ 129) : + evmShr 1 (evmAdd z (evmDiv x z)) = bstep x z ∧ + 2 ^ 127 ≤ bstep x z ∧ bstep x z < 2 ^ 129 := by + have hz_pos : 0 < z := by omega + have hz_wm : z < WORD_MOD := by unfold WORD_MOD; omega + -- x / z < 2^129 since x < 2^256 and z ≥ 2^127 + have hxz_bound : x / z < 2 ^ 129 := by + rw [Nat.div_lt_iff_lt_mul hz_pos] + calc x < WORD_MOD := hx_hi + _ = 2 ^ 256 := rfl + _ = 2 ^ 129 * 2 ^ 127 := by rw [← Nat.pow_add] + _ ≤ 2 ^ 129 * z := Nat.mul_le_mul_left _ hz_lo + have hxz_lt : x / z < WORD_MOD := by unfold WORD_MOD; omega + -- The sum z + x/z < 2^129 + 2^129 = 2^130 < WORD_MOD + have hsum : z + x / z < WORD_MOD := by + have h3 : (2 : Nat) ^ 129 + 2 ^ 129 ≤ WORD_MOD := by unfold WORD_MOD; omega + omega + -- Simplify evmDiv first, then evmAdd, then evmShr + have hdiv_eq : evmDiv x z = x / z := evmDiv_eq' x z hx_hi hz_pos hz_wm + have hadd_eq : evmAdd z (evmDiv x z) = z + x / z := by + rw [hdiv_eq]; exact evmAdd_eq' z (x / z) hz_wm hxz_lt hsum + have hadd_bound : evmAdd z (evmDiv x z) < WORD_MOD := by + rw [hadd_eq]; exact hsum + have hstep_val : evmShr 1 (evmAdd z (evmDiv x z)) = (z + x / z) / 2 := by + have h := evmShr_eq' 1 _ (by omega : (1 : Nat) < 256) hadd_bound + rw [h, hadd_eq, Nat.pow_one] + have hbstep : bstep x z = (z + x / z) / 2 := rfl + constructor + · rw [hstep_val, hbstep] + constructor + -- Lower bound: bstep x z ≥ 2^127 + -- Uses babylon_step_floor_bound with m = 2^127: if m^2 ≤ x then m ≤ bstep x z + · have hmsq : (2 : Nat) ^ 127 * 2 ^ 127 ≤ x := by + have : (2 : Nat) ^ 127 * 2 ^ 127 = 2 ^ 254 := by rw [← Nat.pow_add] + omega + exact babylon_step_floor_bound x z (2 ^ 127) hz_pos hmsq + -- Upper bound: bstep x z < 2^129 + · rw [hbstep] + have hsum_bound : z + x / z < 2 ^ 129 + 2 ^ 129 := by omega + -- (a / 2 < b) when (a < 2 * b) + omega + +/-- EVM bitwise: and(and(1, 255), 255) = 1. -/ +private theorem evmAnd_1_255_255 : evmAnd (evmAnd 1 255) 255 = 1 := by decide + +/-- EVM addition is commutative. -/ +private theorem evmAdd_comm (a b : Nat) : evmAdd a b = evmAdd b a := by + unfold evmAdd u256; rw [Nat.add_comm (a % WORD_MOD) (b % WORD_MOD)] + +/-- The generated model_bstep_evm = bstep when x ∈ [2^254, 2^256) and z ∈ [2^127, 2^129). + Wraps evm_bstep_eq by stripping the u256 wrappers. Also preserves bounds. -/ +private theorem model_bstep_evm_eq_bstep (x z : Nat) + (hx_lo : 2 ^ 254 ≤ x) (hx_hi : x < WORD_MOD) + (hz_lo : 2 ^ 127 ≤ z) (hz_hi : z < 2 ^ 129) : + model_bstep_evm x z = bstep x z ∧ + 2 ^ 127 ≤ bstep x z ∧ bstep x z < 2 ^ 129 := by + have hx_wm : x < WORD_MOD := hx_hi + have hz_wm : z < WORD_MOD := by unfold WORD_MOD; omega + unfold model_bstep_evm + simp only [u256_id' x hx_wm, u256_id' z hz_wm] + rw [evmAnd_1_255_255, evmAdd_comm (evmDiv x z) z] + exact evm_bstep_eq x z hx_lo hx_hi hz_lo hz_hi + +/-- FIXED_SEED < 2^128 < 2^129. -/ +private theorem fixed_seed_lt_2_129 : FIXED_SEED < 2 ^ 129 := by + unfold FIXED_SEED; omega + +/-- FIXED_SEED ≥ 2^127. -/ +private theorem fixed_seed_ge_2_127 : 2 ^ 127 ≤ FIXED_SEED := by + unfold FIXED_SEED; omega + +/-- Sub-lemma B: model_innerSqrt_evm gives (natSqrt(x_hi_1), x_hi_1 - natSqrt(x_hi_1)²). + Unfolds only model_innerSqrt_evm (~10 let-bindings). Each EVM Babylonian step + equals bstep (proved by evm_bstep_eq), and the floor correction matches on + bounded inputs. Together: EVM inner sqrt = floorSqrt_fixed = natSqrt. -/ +private theorem natSqrt_lt_2_128 (x : Nat) (hx : x < 2 ^ 256) : + natSqrt x < 2 ^ 128 := by + suffices h : ¬(2 ^ 128 ≤ natSqrt x) by omega + intro h + have hsq := natSqrt_sq_le x + have hpow : (2 : Nat) ^ 128 * 2 ^ 128 = 2 ^ 256 := by rw [← Nat.pow_add] + have := Nat.mul_le_mul h h + omega + +private theorem natSqrt_ge_2_127 (x : Nat) (hx : 2 ^ 254 ≤ x) : + 2 ^ 127 ≤ natSqrt x := by + suffices h : ¬(natSqrt x < 2 ^ 127) by omega + intro h + have h1 : natSqrt x + 1 ≤ 2 ^ 127 := h + have h2 := Nat.mul_le_mul h1 h1 + have h3 := natSqrt_lt_succ_sq x + have h4 : (2 : Nat) ^ 127 * 2 ^ 127 = 2 ^ 254 := by rw [← Nat.pow_add] + omega + +/-- The norm model's second component is x - fst^2 (definitional). -/ +theorem model_innerSqrt_snd_def (x : Nat) : + (model_innerSqrt x).2 = x - (model_innerSqrt x).1 * (model_innerSqrt x).1 := by + rfl + +/-- The norm model's second component gives the residue x - natSqrt(x)^2. -/ +theorem model_innerSqrt_snd_eq_residue (x : Nat) + (hlo : 2 ^ 254 ≤ x) (hhi : x < 2 ^ 256) : + (model_innerSqrt x).2 = x - natSqrt x * natSqrt x := by + rw [model_innerSqrt_snd_def, model_innerSqrt_fst_eq_natSqrt x hlo hhi] + +/-- EVM inner sqrt computes (natSqrt x, x - natSqrt(x)²) on in-range inputs. + Each EVM Babylonian step equals bstep (since z + x/z < 2^130 < 2^256), + and each step stays in [2^127, 2^129). After 6 steps + correction, the + result matches natSqrt. The residue follows from evmMul/evmSub under bounds. + The bstep chain and correction logic are established once and shared across + both components of the pair. -/ +private theorem evm_innerSqrt_pair (x_hi_1 : Nat) + (hlo : 2 ^ 254 ≤ x_hi_1) (hhi : x_hi_1 < 2 ^ 256) : + (model_innerSqrt_evm x_hi_1).1 = natSqrt x_hi_1 ∧ + (model_innerSqrt_evm x_hi_1).2 = x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1 := by + have hx_wm : x_hi_1 < WORD_MOD := hhi + -- ===== Common setup: bstep chain (established ONCE for both components) ===== + have h1 := model_bstep_evm_eq_bstep x_hi_1 FIXED_SEED hlo hx_wm + fixed_seed_ge_2_127 fixed_seed_lt_2_129 + have h2 := model_bstep_evm_eq_bstep x_hi_1 _ hlo hx_wm h1.2.1 h1.2.2 + have h3 := model_bstep_evm_eq_bstep x_hi_1 _ hlo hx_wm h2.2.1 h2.2.2 + have h4 := model_bstep_evm_eq_bstep x_hi_1 _ hlo hx_wm h3.2.1 h3.2.2 + have h5 := model_bstep_evm_eq_bstep x_hi_1 _ hlo hx_wm h4.2.1 h4.2.2 + have h6 := model_bstep_evm_eq_bstep x_hi_1 _ hlo hx_wm h5.2.1 h5.2.2 + -- Fold 6 bsteps to run6Fixed + have hz6_def : bstep x_hi_1 (bstep x_hi_1 (bstep x_hi_1 (bstep x_hi_1 + (bstep x_hi_1 (bstep x_hi_1 FIXED_SEED))))) = run6Fixed x_hi_1 := by + simp only [run6Fixed, FIXED_SEED, bstep] + -- ===== Common bounds on z6 := run6Fixed x_hi_1 ===== + have hz6_lo : 2 ^ 127 ≤ run6Fixed x_hi_1 := h6.2.1 + have hz6_wm : run6Fixed x_hi_1 < WORD_MOD := by unfold WORD_MOD; omega + have hz6_pos : 0 < run6Fixed x_hi_1 := by omega + -- ===== Common correction proof ===== + have hdiv_eq : evmDiv x_hi_1 (run6Fixed x_hi_1) = x_hi_1 / (run6Fixed x_hi_1) := + evmDiv_eq' x_hi_1 _ hx_wm hz6_pos hz6_wm + have hdiv_wm : x_hi_1 / (run6Fixed x_hi_1) < WORD_MOD := by + unfold WORD_MOD; exact Nat.lt_of_lt_of_le (by + rw [Nat.div_lt_iff_lt_mul hz6_pos] + calc x_hi_1 < 2 ^ 256 := hhi + _ = 2 ^ 129 * 2 ^ 127 := by rw [← Nat.pow_add] + _ ≤ 2 ^ 129 * run6Fixed x_hi_1 := Nat.mul_le_mul_left _ hz6_lo) + (by omega) + have hlt_eq : evmLt (evmDiv x_hi_1 (run6Fixed x_hi_1)) (run6Fixed x_hi_1) = + if x_hi_1 / (run6Fixed x_hi_1) < (run6Fixed x_hi_1) then 1 else 0 := by + rw [hdiv_eq]; exact evmLt_eq' _ _ hdiv_wm hz6_wm + have hlt_le : (if x_hi_1 / (run6Fixed x_hi_1) < (run6Fixed x_hi_1) then 1 + else (0 : Nat)) ≤ run6Fixed x_hi_1 := by split <;> omega + have hsub_corr : evmSub (run6Fixed x_hi_1) (evmLt (evmDiv x_hi_1 (run6Fixed x_hi_1)) + (run6Fixed x_hi_1)) = + (run6Fixed x_hi_1) - (if x_hi_1 / (run6Fixed x_hi_1) < (run6Fixed x_hi_1) + then 1 else 0) := by + rw [hlt_eq]; exact evmSub_eq_of_le _ _ hz6_wm hlt_le + have hbracket := fixed_seed_bracket x_hi_1 hlo hhi + have hcorr_eq : (run6Fixed x_hi_1) - (if x_hi_1 / (run6Fixed x_hi_1) < (run6Fixed x_hi_1) + then 1 else 0) = natSqrt x_hi_1 := by + simp only [Nat.div_lt_iff_lt_mul hz6_pos] + have hcc := correction_correct x_hi_1 (run6Fixed x_hi_1) hbracket.1 hbracket.2 + by_cases hlt : x_hi_1 < run6Fixed x_hi_1 * run6Fixed x_hi_1 + · simp [hlt] at hcc ⊢; omega + · simp [hlt] at hcc ⊢; omega + -- ===== Common natSqrt bounds (for residue) ===== + have hr8 := natSqrt_lt_2_128 x_hi_1 hhi + have hr8_sq_lt : natSqrt x_hi_1 * natSqrt x_hi_1 < WORD_MOD := by + calc natSqrt x_hi_1 * natSqrt x_hi_1 + < 2 ^ 128 * 2 ^ 128 := Nat.mul_lt_mul_of_le_of_lt (Nat.le_of_lt hr8) hr8 (by omega) + _ = WORD_MOD := by unfold WORD_MOD; rw [← Nat.pow_add] + have hr8_sq_le : natSqrt x_hi_1 * natSqrt x_hi_1 ≤ x_hi_1 := natSqrt_sq_le x_hi_1 + have hr8_wm : natSqrt x_hi_1 < WORD_MOD := by unfold WORD_MOD; omega + -- ===== Prove both components using the shared facts ===== + constructor + · -- .1: the corrected sqrt = natSqrt + unfold model_innerSqrt_evm + simp only [u256_id' x_hi_1 hx_wm, + show (240615969168004511545033772477625056927 : Nat) = FIXED_SEED from rfl, + h1.1, h2.1, h3.1, h4.1, h5.1, h6.1] + rw [hz6_def, hsub_corr, hcorr_eq] + · -- .2: the residue = x - natSqrt(x)² + unfold model_innerSqrt_evm + simp only [u256_id' x_hi_1 hx_wm, + show (240615969168004511545033772477625056927 : Nat) = FIXED_SEED from rfl, + h1.1, h2.1, h3.1, h4.1, h5.1, h6.1] + rw [hz6_def, hsub_corr, hcorr_eq, + evmMul_eq' (natSqrt x_hi_1) (natSqrt x_hi_1) hr8_wm hr8_wm, + Nat.mod_eq_of_lt hr8_sq_lt, + evmSub_eq_of_le x_hi_1 _ hx_wm hr8_sq_le] + +/-- EVM inner sqrt equals norm inner sqrt on in-range inputs. -/ +theorem model_innerSqrt_evm_eq_norm (x_hi_1 : Nat) + (hlo : 2 ^ 254 ≤ x_hi_1) (hhi : x_hi_1 < 2 ^ 256) : + model_innerSqrt_evm x_hi_1 = model_innerSqrt x_hi_1 := by + have h := evm_innerSqrt_pair x_hi_1 hlo hhi + ext + · rw [h.1]; exact (model_innerSqrt_fst_eq_natSqrt x_hi_1 hlo hhi).symm + · rw [h.2]; exact (model_innerSqrt_snd_eq_residue x_hi_1 hlo hhi).symm + +theorem model_innerSqrt_evm_correct (x_hi_1 : Nat) + (hlo : 2 ^ 254 ≤ x_hi_1) (hhi : x_hi_1 < 2 ^ 256) : + (model_innerSqrt_evm x_hi_1).1 = natSqrt x_hi_1 ∧ + (model_innerSqrt_evm x_hi_1).2 = x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1 := + evm_innerSqrt_pair x_hi_1 hlo hhi + +/-- Sub-lemma C: model_karatsubaQuotient_evm computes the Karatsuba quotient and + remainder, including carry correction for the 257-bit overflow case. + Unfolds only model_karatsubaQuotient_evm (~6 let-bindings + if-block). -/ +private theorem model_karatsubaQuotient_evm_correct + (res x_lo r_hi : Nat) + (hres : res ≤ 2 * r_hi) + (hxlo : x_lo < 2 ^ 256) (hrhi_lo : 2 ^ 127 ≤ r_hi) (hrhi_hi : r_hi < 2 ^ 128) + (hres_lt : res < 2 ^ 256) : + let n_full := res * 2 ^ 128 + x_lo / 2 ^ 128 + let d := 2 * r_hi + (model_karatsubaQuotient_evm res x_lo r_hi).1 = n_full / d % 2 ^ 256 ∧ + (model_karatsubaQuotient_evm res x_lo r_hi).2 = n_full % d % 2 ^ 256 := by + intro n_full d + -- === Key bounds === + have hres_wm : res < WORD_MOD := hres_lt + have hxlo_wm : x_lo < WORD_MOD := hxlo + have hrhi_wm : r_hi < WORD_MOD := by unfold WORD_MOD; omega + have hd_pos : (0 : Nat) < d := by show 0 < 2 * r_hi; omega + have hd_ge : (2 : Nat) ^ 128 ≤ d := by show 2 ^ 128 ≤ 2 * r_hi; omega + have hd_wm : d < WORD_MOD := by unfold WORD_MOD; omega + have h_wm_sq : WORD_MOD = 2 ^ 128 * 2 ^ 128 := by unfold WORD_MOD; rw [← Nat.pow_add] + have hxlo_hi : x_lo / 2 ^ 128 < 2 ^ 128 := + Nat.div_lt_of_lt_mul (by rw [← Nat.pow_add]; exact hxlo) + have hn_evm_lt : (res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128 < WORD_MOD := by + have := Nat.mod_lt res (Nat.two_pow_pos 128); rw [h_wm_sq]; omega + -- === EVM simplification lemmas === + have hd_eq : evmShl 1 r_hi = d := by + rw [evmShl_eq' 1 r_hi (by omega) hrhi_wm, Nat.pow_one, Nat.mul_comm] + exact Nat.mod_eq_of_lt (by unfold WORD_MOD; omega) + have hshl_res : evmShl 128 res = (res % 2 ^ 128) * 2 ^ 128 := by + rw [evmShl_eq' 128 res (by omega) hres_wm]; exact mul_pow128_mod_word res + have hshr_xlo : evmShr 128 x_lo = x_lo / 2 ^ 128 := + evmShr_eq' 128 x_lo (by omega) hxlo_wm + have hshl_wm : (res % 2 ^ 128) * 2 ^ 128 < WORD_MOD := by + have := Nat.mod_lt res (Nat.two_pow_pos 128); rw [h_wm_sq] + exact Nat.mul_lt_mul_of_pos_right this (Nat.two_pow_pos 128) + have hshr_wm : x_lo / 2 ^ 128 < WORD_MOD := by unfold WORD_MOD; omega + have hn_eq : evmOr (evmShl 128 res) (evmShr 128 x_lo) = + (res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128 := by + rw [hshl_res, hshr_xlo, evmOr_eq' _ _ hshl_wm hshr_wm] + exact or_eq_add_shl (res % 2 ^ 128) (x_lo / 2 ^ 128) 128 hxlo_hi + have hc_eq : evmShr 128 res = res / 2 ^ 128 := + evmShr_eq' 128 res (by omega) hres_wm + -- === Unfold model, inline let-bindings, then simplify EVM ops === + unfold model_karatsubaQuotient_evm + -- Step 1: Inline all let-bindings to make the goal flat + dsimp only + -- Step 2: Remove u256 wrappers and simplify EVM operations + simp only [u256_id' res hres_wm, u256_id' x_lo hxlo_wm, u256_id' r_hi hrhi_wm, + hshl_res, hshr_xlo, hd_eq, hc_eq] + -- The goal is now flat with an if on (res / 2^128 ≠ 0) + split + · -- CARRY case: res / 2^128 ≠ 0 + next hc_ne => + -- Simplify evmOr to n_evm (the EVM-computed n, missing one WORD_MOD from n_full) + have hn_or : evmOr (res % 2 ^ 128 * 2 ^ 128) (x_lo / 2 ^ 128) = + (res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128 := by + rw [evmOr_eq' _ _ hshl_wm hshr_wm, + or_eq_add_shl (res % 2 ^ 128) (x_lo / 2 ^ 128) 128 hxlo_hi] + simp only [hn_or] + -- Abbreviate n_evm for clarity + -- n_evm := (res % 2^128) * 2^128 + x_lo / 2^128 + -- Simplify evmDiv/evmMod on n_evm + have hn_div : evmDiv ((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) d = + ((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) / d := + evmDiv_eq' _ d hn_evm_lt hd_pos hd_wm + have hn_mod : evmMod ((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) d = + ((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) % d := + evmMod_eq' _ d hn_evm_lt hd_pos hd_wm + -- Simplify evmNot 0 = WORD_MOD - 1 + have hnot_eq : evmNot 0 = WORD_MOD - 1 := + evmNot_eq' 0 (by unfold WORD_MOD; omega) + have hnot_wm : WORD_MOD - 1 < WORD_MOD := by omega + have hwm_div : evmDiv (WORD_MOD - 1) d = (WORD_MOD - 1) / d := + evmDiv_eq' _ d hnot_wm hd_pos hd_wm + have hwm_mod : evmMod (WORD_MOD - 1) d = (WORD_MOD - 1) % d := + evmMod_eq' _ d hnot_wm hd_pos hd_wm + simp only [hn_div, hn_mod, hnot_eq, hwm_div, hwm_mod] + -- Now: evmAdd 1 ((WORD_MOD-1) % d) = 1 + (WORD_MOD-1) % d + have hrw_lt : (WORD_MOD - 1) % d < d := Nat.mod_lt _ hd_pos + have hrw_wm : (WORD_MOD - 1) % d < WORD_MOD := + Nat.lt_of_lt_of_le hrw_lt (by unfold WORD_MOD; omega) + have h1_wm : (1 : Nat) < WORD_MOD := by unfold WORD_MOD; omega + have h1rw_sum : 1 + (WORD_MOD - 1) % d < WORD_MOD := + Nat.lt_of_le_of_lt (by omega : 1 + (WORD_MOD - 1) % d ≤ d) (by unfold WORD_MOD; omega) + have hadd_1_rw : evmAdd 1 ((WORD_MOD - 1) % d) = 1 + (WORD_MOD - 1) % d := + evmAdd_eq' 1 _ h1_wm hrw_wm h1rw_sum + simp only [hadd_1_rw] + -- evmAdd (n_evm%d) (1 + (WORD_MOD-1)%d) = R where R = n_evm%d + 1 + (WORD_MOD-1)%d + have hr0_lt : ((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) % d < d := + Nat.mod_lt _ hd_pos + have hr0_wm : ((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) % d < WORD_MOD := + Nat.lt_of_lt_of_le hr0_lt (by unfold WORD_MOD; omega) + have hR_sum : ((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) % d + + (1 + (WORD_MOD - 1) % d) < WORD_MOD := + -- r0 < d and 1 + rw < d + 1, so R < 2*d < 2^130 < WORD_MOD + Nat.lt_of_lt_of_le (by omega : _ < 2 * d) (by unfold WORD_MOD; omega) + have hstep2 : evmAdd (((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) % d) + (1 + (WORD_MOD - 1) % d) = + ((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) % d + (1 + (WORD_MOD - 1) % d) := + evmAdd_eq' _ _ hr0_wm h1rw_sum hR_sum + simp only [hstep2] + -- Abbreviate R = n_evm%d + 1 + (WORD_MOD-1)%d + -- evmDiv R d = R / d, evmMod R d = R % d + have hR_lt2d : ((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) % d + + (1 + (WORD_MOD - 1) % d) < 2 * d := by omega + have hR_wm : ((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) % d + + (1 + (WORD_MOD - 1) % d) < WORD_MOD := hR_sum + have hdiv_R : evmDiv (((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) % d + + (1 + (WORD_MOD - 1) % d)) d = + (((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) % d + + (1 + (WORD_MOD - 1) % d)) / d := + evmDiv_eq' _ d hR_wm hd_pos hd_wm + have hmod_R : evmMod (((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) % d + + (1 + (WORD_MOD - 1) % d)) d = + (((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) % d + + (1 + (WORD_MOD - 1) % d)) % d := + evmMod_eq' _ d hR_wm hd_pos hd_wm + simp only [hdiv_R, hmod_R] + -- evmAdd (n_evm/d) ((WORD_MOD-1)/d) = n_evm/d + (WORD_MOD-1)/d + have hq0_wm : ((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) / d < WORD_MOD := by + unfold WORD_MOD; exact Nat.lt_of_le_of_lt (Nat.div_le_self _ _) hn_evm_lt + have hqw_wm : (WORD_MOD - 1) / d < WORD_MOD := by + unfold WORD_MOD; exact Nat.lt_of_le_of_lt (Nat.div_le_self _ _) hnot_wm + -- Tighter bounds: q0 < 2^128 and qw < 2^128 (from n < 2^256 and d ≥ 2^128) + have hq0_128 : ((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) / d < 2 ^ 128 := + (Nat.div_lt_iff_lt_mul hd_pos).mpr (Nat.lt_of_lt_of_le hn_evm_lt + (by rw [h_wm_sq]; exact Nat.mul_le_mul_left _ hd_ge)) + have hqw_128 : (WORD_MOD - 1) / d < 2 ^ 128 := + (Nat.div_lt_iff_lt_mul hd_pos).mpr (Nat.lt_of_lt_of_le hnot_wm + (by rw [h_wm_sq]; exact Nat.mul_le_mul_left _ hd_ge)) + have h129_le_wm : (2 : Nat) ^ 129 ≤ WORD_MOD := by + unfold WORD_MOD; exact Nat.pow_le_pow_right (by omega) (by omega) + have hq0qw_sum : ((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) / d + + (WORD_MOD - 1) / d < WORD_MOD := + Nat.lt_of_lt_of_le (by omega : _ < 2 ^ 129) h129_le_wm + have hstep1 : evmAdd (((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) / d) + ((WORD_MOD - 1) / d) = + ((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) / d + (WORD_MOD - 1) / d := + evmAdd_eq' _ _ hq0_wm hqw_wm hq0qw_sum + simp only [hstep1] + -- evmAdd (q0+qw) (R/d) = q0+qw+R/d + have hR_div_le1 : (((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) % d + + (1 + (WORD_MOD - 1) % d)) / d ≤ 1 := + Nat.lt_succ_iff.mp ((Nat.div_lt_iff_lt_mul hd_pos).mpr hR_lt2d) + have hR_div_wm : (((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) % d + + (1 + (WORD_MOD - 1) % d)) / d < WORD_MOD := + Nat.lt_of_le_of_lt hR_div_le1 (by unfold WORD_MOD; omega) + have hfinal_sum : ((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) / d + + (WORD_MOD - 1) / d + (((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) % d + + (1 + (WORD_MOD - 1) % d)) / d < WORD_MOD := + Nat.lt_of_lt_of_le (by omega : _ < 2 ^ 129 + 1) (by omega : 2 ^ 129 + 1 ≤ WORD_MOD) + have hstep3 : evmAdd (((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) / d + + (WORD_MOD - 1) / d) + ((((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) % d + + (1 + (WORD_MOD - 1) % d)) / d) = + ((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) / d + (WORD_MOD - 1) / d + + (((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) % d + + (1 + (WORD_MOD - 1) % d)) / d := + evmAdd_eq' _ _ hq0qw_sum hR_div_wm hfinal_sum + simp only [hstep3] + -- === Now the goal is pure Nat === + -- Show these equal n_full/d and n_full%d via the carry correction identity + -- n_full = n_evm + WORD_MOD where n_evm = (res%2^128)*2^128 + x_lo/2^128 + have hc_one : res / 2 ^ 128 = 1 := by + have hc_pos : 0 < res / 2 ^ 128 := Nat.pos_of_ne_zero hc_ne + have hc_le : res / 2 ^ 128 ≤ 1 := by + have : res / 2 ^ 128 < 2 := (Nat.div_lt_iff_lt_mul (Nat.two_pow_pos 128)).mpr (by omega) + omega + omega + have hn_full_eq : n_full = + (res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128 + WORD_MOD := by + show res * 2 ^ 128 + x_lo / 2 ^ 128 = + (res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128 + WORD_MOD + have h := Nat.div_add_mod res (2 ^ 128); rw [hc_one] at h; rw [h_wm_sq]; omega + -- n_full = d * (q0 + qw) + R + -- where q0 = n_evm/d, qw = (WORD_MOD-1)/d, R = n_evm%d + 1 + (WORD_MOD-1)%d + have hn_full_decomp : n_full = + d * (((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) / d + (WORD_MOD - 1) / d) + + (((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) % d + (1 + (WORD_MOD - 1) % d)) := by + rw [hn_full_eq] + have h1 := (Nat.div_add_mod ((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) d).symm + have h2 := (Nat.div_add_mod (WORD_MOD - 1) d).symm + rw [Nat.mul_add]; omega + -- Apply div_of_mul_add and mod_of_mul_add + rw [show (2 : Nat) ^ 256 = WORD_MOD from rfl] + have hn_div : n_full / d = + ((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) / d + (WORD_MOD - 1) / d + + (((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) % d + (1 + (WORD_MOD - 1) % d)) / d := by + rw [hn_full_decomp]; exact div_of_mul_add d _ _ hd_pos + have hn_mod : n_full % d = + (((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) % d + (1 + (WORD_MOD - 1) % d)) % d := by + rw [hn_full_decomp]; exact mod_of_mul_add d _ _ + have hn_full_mod_wm : n_full % d < WORD_MOD := + Nat.lt_of_lt_of_le (Nat.mod_lt n_full hd_pos) (by unfold WORD_MOD; omega) + refine ⟨?_, ?_⟩ + · rw [hn_div]; exact (Nat.mod_eq_of_lt hfinal_sum).symm + · rw [hn_mod] + have : (((res % 2 ^ 128) * 2 ^ 128 + x_lo / 2 ^ 128) % d + + (1 + (WORD_MOD - 1) % d)) % d < WORD_MOD := + Nat.lt_of_lt_of_le (Nat.mod_lt _ hd_pos) (by unfold WORD_MOD; omega) + exact (Nat.mod_eq_of_lt this).symm + · -- NO CARRY case + next hc_not => + have hc_zero : res / 2 ^ 128 = 0 := Decidable.byContradiction hc_not + have hres_128 : res < 2 ^ 128 := by + suffices ¬(2 ^ 128 ≤ res) by omega + intro h; exact absurd hc_zero (Nat.ne_of_gt (Nat.div_pos h (Nat.two_pow_pos 128))) + have hmod_res : res % 2 ^ 128 = res := Nat.mod_eq_of_lt hres_128 + -- Simplify evmOr to Nat addition = n_full + have hn_or : evmOr (res % 2 ^ 128 * 2 ^ 128) (x_lo / 2 ^ 128) = n_full := by + rw [evmOr_eq' _ _ hshl_wm hshr_wm, + or_eq_add_shl (res % 2 ^ 128) (x_lo / 2 ^ 128) 128 hxlo_hi, hmod_res] + -- n_full < WORD_MOD (n_full = res*2^128 + x_lo/2^128, and res%2^128 = res) + have hn_full_wm : n_full < WORD_MOD := by + show res * 2 ^ 128 + x_lo / 2 ^ 128 < WORD_MOD + rw [← hmod_res]; exact hn_evm_lt + -- Reduce .fst/.snd, rewrite evmOr, simplify evmDiv/evmMod + simp only [hn_or] + rw [evmDiv_eq' n_full d hn_full_wm hd_pos hd_wm, + evmMod_eq' n_full d hn_full_wm hd_pos hd_wm, + show (2 : Nat) ^ 256 = WORD_MOD from rfl] + exact ⟨(Nat.mod_eq_of_lt (Nat.lt_of_le_of_lt (Nat.div_le_self _ _) hn_full_wm)).symm, + (Nat.mod_eq_of_lt (Nat.lt_of_lt_of_le (Nat.mod_lt n_full hd_pos) (by unfold WORD_MOD; omega))).symm⟩ + +/-- Sub-lemma D: model_sqrtCorrection_evm computes the raw correction step. + Given r_hi (high sqrt), r_lo (Karatsuba quotient), rem (Karatsuba remainder), x_lo: + result = r_hi * 2^128 + r_lo - (if rem * 2^128 + x_lo % 2^128 < r_lo * r_lo then 1 else 0) + The 257-bit split comparison correctly evaluates rem*2^128 + x_lo_lo < r_lo^2. -/ +private theorem model_sqrtCorrection_evm_correct + (r_hi r_lo rem x_lo : Nat) + (hrhi_lo : 2 ^ 127 ≤ r_hi) (hrhi_hi : r_hi < 2 ^ 128) + (hrlo_le : r_lo ≤ 2 ^ 128) (hrem : rem < 2 * r_hi) + (hxlo : x_lo < 2 ^ 256) + (hedge : r_lo = 2 ^ 128 → rem < 2 ^ 128) : + model_sqrtCorrection_evm r_hi r_lo rem x_lo = + r_hi * 2 ^ 128 + r_lo - + (if rem * 2 ^ 128 + x_lo % 2 ^ 128 < r_lo * r_lo then 1 else 0) := by + have hrhi_wm : r_hi < WORD_MOD := by unfold WORD_MOD; omega + have hrlo_wm : r_lo < WORD_MOD := by unfold WORD_MOD; omega + have hrem_wm : rem < WORD_MOD := by unfold WORD_MOD; omega + have hxlo_wm : x_lo < WORD_MOD := hxlo + have hrem_129 : rem < 2 ^ 129 := by omega + have h_wm_sq : WORD_MOD = 2 ^ 128 * 2 ^ 128 := by unfold WORD_MOD; rw [← Nat.pow_add] + -- Constant-fold: evmAnd(evmAnd(128, 255), 255) = 128 + have hcf128 : evmAnd (evmAnd 128 255) 255 = 128 := by decide + -- 340282366920938463463374607431768211455 = 2^128 - 1 + have hmask : (340282366920938463463374607431768211455 : Nat) = 2 ^ 128 - 1 := by decide + -- Unfold and inline let-bindings + unfold model_sqrtCorrection_evm + dsimp only + simp only [u256_id' r_hi hrhi_wm, u256_id' r_lo hrlo_wm, u256_id' rem hrem_wm, + u256_id' x_lo hxlo_wm, hcf128, hmask] + -- === Simplify each EVM operation === + -- evmShl 128 r_hi, evmShr 128 rem, evmShr 128 r_lo + have hshl_rhi : evmShl 128 r_hi = r_hi * 2 ^ 128 := by + rw [evmShl_eq' 128 r_hi (by omega) hrhi_wm] + exact Nat.mod_eq_of_lt (by rw [h_wm_sq]; exact Nat.mul_lt_mul_of_pos_right hrhi_hi (Nat.two_pow_pos 128)) + have hshr_rem : evmShr 128 rem = rem / 2 ^ 128 := evmShr_eq' 128 rem (by omega) hrem_wm + have hshr_rlo : evmShr 128 r_lo = r_lo / 2 ^ 128 := evmShr_eq' 128 r_lo (by omega) hrlo_wm + -- evmShl 128 rem = (rem % 2^128) * 2^128 + have hshl_rem : evmShl 128 rem = (rem % 2 ^ 128) * 2 ^ 128 := by + rw [evmShl_eq' 128 rem (by omega) hrem_wm]; exact mul_pow128_mod_word rem + -- evmAnd x_lo (2^128-1) = x_lo % 2^128 + have hand_mask : evmAnd x_lo (2 ^ 128 - 1) = x_lo % (2 ^ 128) := by + rw [evmAnd_eq' x_lo (2 ^ 128 - 1) hxlo_wm (by unfold WORD_MOD; omega)] + exact Nat.and_two_pow_sub_one_eq_mod x_lo 128 + -- evmOr for res_lo_concat: (rem%2^128)*2^128 + x_lo%2^128 + have hshl_rem_wm : (rem % 2 ^ 128) * 2 ^ 128 < WORD_MOD := by + rw [h_wm_sq]; exact Nat.mul_lt_mul_of_pos_right (Nat.mod_lt rem (Nat.two_pow_pos 128)) (Nat.two_pow_pos 128) + have hxlo_mod_lt : x_lo % 2 ^ 128 < 2 ^ 128 := Nat.mod_lt x_lo (Nat.two_pow_pos 128) + have hxlo_mod_wm : x_lo % 2 ^ 128 < WORD_MOD := by unfold WORD_MOD; omega + have hor_concat : evmOr (evmShl 128 rem) (evmAnd x_lo (2 ^ 128 - 1)) = + (rem % 2 ^ 128) * 2 ^ 128 + x_lo % 2 ^ 128 := by + rw [hshl_rem, hand_mask, evmOr_eq' _ _ hshl_rem_wm hxlo_mod_wm, + or_eq_add_shl (rem % 2 ^ 128) (x_lo % 2 ^ 128) 128 hxlo_mod_lt] + -- evmMul r_lo r_lo = (r_lo * r_lo) % WORD_MOD + have hmul_rlo : evmMul r_lo r_lo = (r_lo * r_lo) % WORD_MOD := + evmMul_eq' r_lo r_lo hrlo_wm hrlo_wm + -- evmLt, evmEq simplifications + have hrem_hi_wm : rem / 2 ^ 128 < WORD_MOD := by unfold WORD_MOD; omega + have hrlo_hi_wm : r_lo / 2 ^ 128 < WORD_MOD := by unfold WORD_MOD; omega + have hconcat_wm : (rem % 2 ^ 128) * 2 ^ 128 + x_lo % 2 ^ 128 < WORD_MOD := by + rw [h_wm_sq]; omega + have hmul_wm : (r_lo * r_lo) % WORD_MOD < WORD_MOD := Nat.mod_lt _ (by unfold WORD_MOD; omega) + have hlt_hi : evmLt (evmShr 128 rem) (evmShr 128 r_lo) = + if rem / 2 ^ 128 < r_lo / 2 ^ 128 then 1 else 0 := by + rw [hshr_rem, hshr_rlo]; exact evmLt_eq' _ _ hrem_hi_wm hrlo_hi_wm + have heq_hi : evmEq (evmShr 128 rem) (evmShr 128 r_lo) = + if rem / 2 ^ 128 = r_lo / 2 ^ 128 then 1 else 0 := by + rw [hshr_rem, hshr_rlo]; exact evmEq_eq' _ _ hrem_hi_wm hrlo_hi_wm + have hlt_lo : evmLt (evmOr (evmShl 128 rem) (evmAnd x_lo (2 ^ 128 - 1))) (evmMul r_lo r_lo) = + if (rem % 2 ^ 128) * 2 ^ 128 + x_lo % 2 ^ 128 < (r_lo * r_lo) % WORD_MOD then 1 else 0 := by + rw [hor_concat, hmul_rlo]; exact evmLt_eq' _ _ hconcat_wm hmul_wm + -- Combine the comparison: cmp = evmOr(lt_hi, evmAnd(eq_hi, lt_lo)) + simp only [hlt_hi, heq_hi, hlt_lo, hshl_rhi] + -- Now simplify evmAnd/evmOr on {0,1} comparison results, and evmAdd/evmSub + -- The goal has the form: + -- evmSub (evmAdd (r_hi*2^128) r_lo) + -- (evmOr (if rem_hi < rlo_hi then 1 else 0) + -- (evmAnd (if rem_hi = rlo_hi then 1 else 0) + -- (if res_lo_concat < rlo_sq_mod then 1 else 0))) + -- = r_hi*2^128 + r_lo - (if rem*2^128 + x_lo%2^128 < r_lo*r_lo then 1 else 0) + -- + -- Key: the 257-bit comparison via (hi_lt || (hi_eq && lo_lt)) correctly evaluates + -- rem * 2^128 + x_lo % 2^128 < r_lo * r_lo + -- where: + -- rem_hi = rem / 2^128, rlo_hi = r_lo / 2^128 + -- LHS_lo = (rem % 2^128) * 2^128 + x_lo % 2^128 + -- RHS_lo = (r_lo * r_lo) % WORD_MOD + -- LHS = rem_hi * WORD_MOD + LHS_lo, RHS = rlo_hi * WORD_MOD + RHS_lo (conceptually) + -- Since rem < 2*r_hi < 2^129, rem_hi ∈ {0,1} + -- Since r_lo ≤ 2^128, rlo_hi ∈ {0,1} + -- rem / 2^128 ≤ 1 + have hrem_hi_le : rem / 2 ^ 128 ≤ 1 := + Nat.lt_succ_iff.mp ((Nat.div_lt_iff_lt_mul (Nat.two_pow_pos 128)).mpr (by omega)) + -- r_lo / 2^128 ≤ 1 + have hrlo_hi_le : r_lo / 2 ^ 128 ≤ 1 := by + have : r_lo / 2 ^ 128 < 2 := (Nat.div_lt_iff_lt_mul (Nat.two_pow_pos 128)).mpr (by omega) + omega + -- === Simplify evmAnd/evmOr on {0,1} values === + -- evmAnd (if a then 1 else 0) (if b then 1 else 0) = + -- if a ∧ b then 1 else 0 + -- evmOr (if a then 1 else 0) (if b then 1 else 0) = + -- if a ∨ b then 1 else 0 + -- These follow because the values are 0 or 1, which are < WORD_MOD + -- After expanding: the cmp value is (if (rem_hi < rlo_hi) ∨ (rem_hi = rlo_hi ∧ lo_lt) then 1 else 0) + -- where lo_lt = ((rem%2^128)*2^128 + x_lo%2^128 < (r_lo*r_lo)%WORD_MOD) + -- + -- Need: this equals (if rem*2^128 + x_lo%2^128 < r_lo*r_lo then 1 else 0) + -- This is the 257-bit comparison correctness. + -- + -- And then: evmSub(evmAdd(r_hi*2^128, r_lo), cmp) = r_hi*2^128 + r_lo - cmp + -- Case split on rem / 2^128 and r_lo / 2^128 + have hrem_hi_cases : rem / 2 ^ 128 = 0 ∨ rem / 2 ^ 128 = 1 := by omega + have hrlo_hi_cases : r_lo / 2 ^ 128 = 0 ∨ r_lo / 2 ^ 128 = 1 := by omega + rcases hrem_hi_cases with hremh | hremh <;> rcases hrlo_hi_cases with hrloh | hrloh <;> + simp only [hremh, hrloh] + · -- Case (0,0): rem < 2^128, r_lo < 2^128 + -- Reduce: if 0 < 0 → 0, if True → 1 + have h00 : (if (0 : Nat) < 0 then 1 else 0) = 0 := by decide + simp only [h00, ite_true] + -- evmOr 0 (evmAnd 1 x) where x = if P then 1 else 0 + -- evmAnd 1 (if P then 1 else 0) = if P then 1 else 0 + have hand1 : ∀ (n : Nat), n ≤ 1 → + evmAnd 1 n = n := by + intro n hn; rcases Nat.le_one_iff_eq_zero_or_eq_one.mp hn with rfl | rfl <;> decide + -- evmOr 0 n = n for n ≤ 1 + have hor0 : ∀ (n : Nat), n ≤ 1 → evmOr 0 n = n := by + intro n hn; rcases Nat.le_one_iff_eq_zero_or_eq_one.mp hn with rfl | rfl <;> decide + -- Simplify: rem < 2^128 → rem % 2^128 = rem + have hrem_lt : rem < 2 ^ 128 := by omega + have hrem_mod : rem % 2 ^ 128 = rem := Nat.mod_eq_of_lt hrem_lt + -- r_lo < 2^128 → r_lo * r_lo < WORD_MOD → r_lo*r_lo % WORD_MOD = r_lo*r_lo + have hrlo_lt : r_lo < 2 ^ 128 := by omega + have hrlo_sq_lt : r_lo * r_lo < WORD_MOD := by + have := Nat.mul_le_mul_left r_lo (show r_lo ≤ 2 ^ 128 from by omega) + have := Nat.mul_lt_mul_of_pos_right hrlo_lt (Nat.two_pow_pos 128) + rw [h_wm_sq]; omega + have hmod_sq : r_lo * r_lo % WORD_MOD = r_lo * r_lo := Nat.mod_eq_of_lt hrlo_sq_lt + rw [hrem_mod, hmod_sq] + -- Now both comparisons match, simplify evmAnd/evmOr + have hcmp_le : (if rem * 2 ^ 128 + x_lo % 2 ^ 128 < r_lo * r_lo then 1 else (0 : Nat)) ≤ 1 := by + split <;> omega + rw [hand1 _ hcmp_le, hor0 _ hcmp_le] + -- Simplify evmAdd/evmSub + have hrhi_mul_lt : r_hi * 2 ^ 128 < WORD_MOD := by + rw [h_wm_sq]; exact Nat.mul_lt_mul_of_pos_right hrhi_hi (Nat.two_pow_pos 128) + have hadd_lt : r_hi * 2 ^ 128 + r_lo < WORD_MOD := by omega + have hcmp_le_sum : (if rem * 2 ^ 128 + x_lo % 2 ^ 128 < r_lo * r_lo then 1 else 0) ≤ + r_hi * 2 ^ 128 + r_lo := by + split <;> omega + rw [evmAdd_eq' _ _ hrhi_mul_lt hrlo_wm hadd_lt, + evmSub_eq_of_le _ _ hadd_lt hcmp_le_sum] + · -- Case (0,1): rem < 2^128, r_lo / 2^128 = 1 → r_lo = 2^128 + have hrlo_eq : r_lo = 2 ^ 128 := by omega + -- Reduce ifs + have h01a : (if (0 : Nat) < 1 then 1 else 0) = 1 := by decide + have h01b : (if (0 : Nat) = 1 then 1 else 0) = 0 := by decide + simp only [h01a, h01b] + -- evmAnd 0 _ = 0 + have hand0 : ∀ x, evmAnd 0 x = 0 := by + intro x; unfold evmAnd u256; simp + simp only [hand0] + -- evmOr 1 0 = 1 + have : evmOr 1 0 = 1 := by decide + simp only [this] + -- RHS comparison is true: rem*2^128 + x_lo%2^128 < 2^128*2^128 = 2^256 + have hrem_lt : rem < 2 ^ 128 := by omega + have hcmp_true : rem * 2 ^ 128 + x_lo % 2 ^ 128 < r_lo * r_lo := by + rw [hrlo_eq, show (2 : Nat) ^ 128 * 2 ^ 128 = 2 ^ 256 from by rw [← Nat.pow_add]] + have := Nat.mod_lt x_lo (Nat.two_pow_pos 128); omega + simp only [hcmp_true, ↓reduceIte] + rw [hrlo_eq] + -- Now: evmSub (evmAdd (r_hi * 2^128) (2^128)) 1 = r_hi * 2^128 + 2^128 - 1 + -- Two subcases: overflow or not + by_cases hoverflow : r_hi * 2 ^ 128 + 2 ^ 128 < WORD_MOD + · -- No overflow + rw [evmAdd_eq' _ _ (by omega) (by unfold WORD_MOD; omega) hoverflow, + evmSub_eq_of_le _ 1 hoverflow (by omega)] + · -- Overflow: r_hi * 2^128 + 2^128 = WORD_MOD + have hsum_eq : r_hi * 2 ^ 128 + 2 ^ 128 = WORD_MOD := by + have : r_hi * 2 ^ 128 + 2 ^ 128 ≤ WORD_MOD := by + rw [h_wm_sq, ← Nat.succ_mul] + exact Nat.mul_le_mul_right _ hrhi_hi + omega + rw [evmSub_evmAdd_eq_of_overflow _ _ (by omega) (by unfold WORD_MOD; omega) hsum_eq] + omega + · -- Case (1,0): rem / 2^128 = 1, r_lo < 2^128 + -- Reduce if 1 < 0 → 0, if 1 = 0 → 0 + have h10a : (if (1 : Nat) < 0 then 1 else 0) = 0 := by decide + have h10b : (if (1 : Nat) = 0 then 1 else 0) = 0 := by decide + simp only [h10a, h10b] + -- evmAnd 0 _ = 0 + have hand0 : ∀ x, evmAnd 0 x = 0 := by + intro x; unfold evmAnd u256; simp + simp only [hand0] + -- evmOr 0 0 = 0 + have hor00 : evmOr 0 0 = 0 := by decide + simp only [hor00] + -- RHS comparison: rem ≥ 2^128, r_lo < 2^128 → comparison false + have hrlo_lt : r_lo < 2 ^ 128 := by omega + have hrlo_sq_lt : r_lo * r_lo < WORD_MOD := by + have h1 := Nat.mul_le_mul_left r_lo (show r_lo ≤ 2 ^ 128 from by omega) + have h2 := Nat.mul_lt_mul_of_pos_right (show r_lo < 2 ^ 128 from by omega) (Nat.two_pow_pos 128) + rw [h_wm_sq]; omega + have hcmp_false : ¬(rem * 2 ^ 128 + x_lo % 2 ^ 128 < r_lo * r_lo) := by omega + simp only [hcmp_false, ↓reduceIte, Nat.sub_zero] + -- Simplify evmSub (evmAdd ...) 0 + have hrhi_mul_lt : r_hi * 2 ^ 128 < WORD_MOD := by + rw [h_wm_sq]; exact Nat.mul_lt_mul_of_pos_right hrhi_hi (Nat.two_pow_pos 128) + have hadd_lt : r_hi * 2 ^ 128 + r_lo < WORD_MOD := by omega + rw [evmAdd_eq' _ _ hrhi_mul_lt hrlo_wm hadd_lt, + evmSub_eq_of_le _ 0 hadd_lt (Nat.zero_le _)] + omega + · -- Case (1,1): contradiction (hedge + rem ≥ 2^128 + r_lo = 2^128) + -- r_lo / 2^128 = 1, r_lo ≤ 2^128 → r_lo = 2^128 + have hrlo_eq : r_lo = 2 ^ 128 := by omega + -- hedge: r_lo = 2^128 → rem < 2^128 + have hrem_lt : rem < 2 ^ 128 := hedge hrlo_eq + -- But rem / 2^128 = 1 → rem ≥ 2^128 + exfalso; omega + +/-- Composition of the three EVM sub-models equals karatsubaFloor on normalized inputs. -/ +private theorem evm_composition_eq_karatsubaFloor (x_hi_1 x_lo_1 : Nat) + (hlo : 2 ^ 254 ≤ x_hi_1) (hhi : x_hi_1 < 2 ^ 256) (hxlo : x_lo_1 < 2 ^ 256) : + model_sqrtCorrection_evm + (model_innerSqrt_evm x_hi_1).1 + (model_karatsubaQuotient_evm (model_innerSqrt_evm x_hi_1).2 x_lo_1 + (model_innerSqrt_evm x_hi_1).1).1 + (model_karatsubaQuotient_evm (model_innerSqrt_evm x_hi_1).2 x_lo_1 + (model_innerSqrt_evm x_hi_1).1).2 + x_lo_1 + = karatsubaFloor x_hi_1 x_lo_1 := by + -- Step 1: Replace EVM sub-model outputs with their Nat equivalents + have hinner := model_innerSqrt_evm_correct x_hi_1 hlo hhi + rw [hinner.1, hinner.2] + -- Abbreviations used in comments: + -- m = natSqrt x_hi_1, res = x_hi_1 - m², H = 2^128 + -- n = res*H + x_lo_1/H, d = 2*m, q = n/d, rem = n%d, r = m*H + q + -- Step 2: Bounds on natSqrt x_hi_1 and residue + have hrhi_lo : 2 ^ 127 ≤ natSqrt x_hi_1 := natSqrt_ge_2_127 x_hi_1 hlo + have hrhi_hi : natSqrt x_hi_1 < 2 ^ 128 := natSqrt_lt_2_128 x_hi_1 hhi + have hres_le : x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1 ≤ 2 * natSqrt x_hi_1 := by + have hsq := natSqrt_sq_le x_hi_1 + have hsucc := natSqrt_lt_succ_sq x_hi_1 + -- (m+1)*(m+1) = m*m + 2*m + 1, so x - m*m ≤ 2*m + have := Nat.add_mul (natSqrt x_hi_1) 1 (natSqrt x_hi_1 + 1) + have := Nat.mul_add (natSqrt x_hi_1) (natSqrt x_hi_1) 1 + omega + have hres_lt : x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1 < 2 ^ 256 := by omega + -- Step 3: Apply model_karatsubaQuotient_evm_correct + have hkq := model_karatsubaQuotient_evm_correct + (x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) x_lo_1 (natSqrt x_hi_1) + hres_le hxlo hrhi_lo hrhi_hi hres_lt + rw [hkq.1, hkq.2] + -- Step 4: Strip % 2^256 (both q and rem fit in 256 bits) + have hd_pos : 0 < 2 * natSqrt x_hi_1 := by omega + have hxlo_hi : x_lo_1 / 2 ^ 128 < 2 ^ 128 := + Nat.div_lt_of_lt_mul (by rw [← Nat.pow_add]; exact hxlo) + have hq_le : + ((x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 / 2 ^ 128) / + (2 * natSqrt x_hi_1) ≤ 2 ^ 128 := by + rw [Nat.div_le_iff_le_mul_add_pred hd_pos]; omega + have hq_lt_256 : + ((x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 / 2 ^ 128) / + (2 * natSqrt x_hi_1) < 2 ^ 256 := by omega + have hrem_lt_256 : + ((x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 / 2 ^ 128) % + (2 * natSqrt x_hi_1) < 2 ^ 256 := + Nat.lt_of_lt_of_le (Nat.mod_lt _ hd_pos) (by omega) + rw [Nat.mod_eq_of_lt hq_lt_256, Nat.mod_eq_of_lt hrem_lt_256] + -- Step 5: Hedge condition: q = 2^128 → rem < 2^128 + have hedge : + ((x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 / 2 ^ 128) / + (2 * natSqrt x_hi_1) = 2 ^ 128 → + ((x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 / 2 ^ 128) % + (2 * natSqrt x_hi_1) < 2 ^ 128 := by + intro hq_eq + have hid := (Nat.div_add_mod + ((x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 / 2 ^ 128) + (2 * natSqrt x_hi_1)).symm + rw [hq_eq] at hid -- d * 2^128 + rem = n + have hres_eq_d : x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1 = 2 * natSqrt x_hi_1 := by omega + rw [hres_eq_d, show 2 * natSqrt x_hi_1 * 2 ^ 128 + x_lo_1 / 2 ^ 128 = + x_lo_1 / 2 ^ 128 + 2 ^ 128 * (2 * natSqrt x_hi_1) from by omega, + Nat.add_mul_mod_self_right, + Nat.mod_eq_of_lt (by omega : x_lo_1 / 2 ^ 128 < 2 * natSqrt x_hi_1)] + exact hxlo_hi + -- Step 6: Apply model_sqrtCorrection_evm_correct + have hcorr := model_sqrtCorrection_evm_correct (natSqrt x_hi_1) + (((x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 / 2 ^ 128) / + (2 * natSqrt x_hi_1)) + (((x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 / 2 ^ 128) % + (2 * natSqrt x_hi_1)) + x_lo_1 hrhi_lo hrhi_hi hq_le (Nat.mod_lt _ hd_pos) hxlo hedge + rw [hcorr] + -- Goal: m*H + q - (if rem*H + x_lo_lo < q*q then 1 else 0) = karatsubaFloor x_hi_1 x_lo_1 + -- Step 7: Both sides equal natSqrt(x_hi_1*2^256+x_lo_1) + rw [karatsubaFloor_eq_natSqrt x_hi_1 x_lo_1 hlo hxlo] + -- Goal: m*H + q - correction = natSqrt(x_hi_1*2^256+x_lo_1) + -- Step 8: r = karatsubaR x_hi_1 x_lo_1 + have hr_eq : natSqrt x_hi_1 * 2 ^ 128 + + ((x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 / 2 ^ 128) / + (2 * natSqrt x_hi_1) = karatsubaR x_hi_1 x_lo_1 := by + unfold karatsubaR; rfl + -- Step 9: x_full = x_hi_1*2^256 + x_lo_1 + have hx_full : x_hi_1 * (2 ^ 128 * 2 ^ 128) + x_lo_1 / 2 ^ 128 * 2 ^ 128 + + x_lo_1 % 2 ^ 128 = x_hi_1 * 2 ^ 256 + x_lo_1 := by + have h1 := Nat.div_add_mod x_lo_1 (2 ^ 128) + have h2 : (2 : Nat) ^ 128 * 2 ^ 128 = 2 ^ 256 := by rw [← Nat.pow_add] + omega + -- Step 10: karatsubaR_bracket gives natSqrt(x) ≤ r ≤ natSqrt(x)+1 + have hbracket := karatsubaR_bracket x_hi_1 x_lo_1 hlo hxlo + have hbr1 : natSqrt (x_hi_1 * 2 ^ 256 + x_lo_1) ≤ karatsubaR x_hi_1 x_lo_1 := by + have := hbracket.1; rwa [hx_full] at this + have hbr2 : karatsubaR x_hi_1 x_lo_1 ≤ natSqrt (x_hi_1 * 2 ^ 256 + x_lo_1) + 1 := by + have := hbracket.2; rwa [hx_full] at this + -- Step 11: correction_correct gives the answer + have hcc := correction_correct (x_hi_1 * 2 ^ 256 + x_lo_1) (karatsubaR x_hi_1 x_lo_1) hbr1 hbr2 + -- hcc: (if x < r*r then r-1 else r) = natSqrt(x) + rw [← hr_eq] at hcc + -- Step 12: Karatsuba identity x_full + q² = r² + rem*H + x_lo_lo + -- Abbreviate for readability + have hident : + let n := (x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 / 2 ^ 128 + let q := n / (2 * natSqrt x_hi_1) + let rem := n % (2 * natSqrt x_hi_1) + let r := natSqrt x_hi_1 * 2 ^ 128 + q + x_hi_1 * (2 ^ 128 * 2 ^ 128) + x_lo_1 / 2 ^ 128 * 2 ^ 128 + x_lo_1 % 2 ^ 128 + + q * q = r * r + rem * 2 ^ 128 + x_lo_1 % 2 ^ 128 := by + simp only [] + -- Both sides equal m²*(H*H) + n*H + q² + x_lo%H after expansion. + -- Key facts for the algebraic identity: + have hsq_le := natSqrt_sq_le x_hi_1 + have hdivmod := (Nat.div_add_mod + ((x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 / 2 ^ 128) + (2 * natSqrt x_hi_1)).symm + -- Product decompositions (making the identity linear for omega): + -- (m²+res)*(H*H) = m²*(H*H) + res*(H*H) + have hp1 := Nat.add_mul (natSqrt x_hi_1 * natSqrt x_hi_1) + (x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) (2 ^ 128 * 2 ^ 128 : Nat) + -- res*(H*H) = res*H*H + have hp2 := Nat.mul_assoc (x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) + (2 ^ 128 : Nat) (2 ^ 128 : Nat) + -- (res*H + x_lo_hi)*H = res*H*H + x_lo_hi*H + have hp3 := Nat.add_mul ((x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128) + (x_lo_1 / 2 ^ 128) (2 ^ 128 : Nat) + -- Square expansion: (m*H+q)*(m*H+q) = m*H*(m*H+q) + q*(m*H+q) + have hp4 := Nat.add_mul (natSqrt x_hi_1 * 2 ^ 128) + (((x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 / 2 ^ 128) / + (2 * natSqrt x_hi_1)) + (natSqrt x_hi_1 * 2 ^ 128 + + ((x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 / 2 ^ 128) / + (2 * natSqrt x_hi_1)) + -- m*H*(m*H+q) = m*H*(m*H) + m*H*q + have hp5 := Nat.mul_add (natSqrt x_hi_1 * 2 ^ 128) + (natSqrt x_hi_1 * 2 ^ 128) + (((x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 / 2 ^ 128) / + (2 * natSqrt x_hi_1)) + -- q*(m*H+q) = q*(m*H) + q*q + have hp6 := Nat.mul_add + (((x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 / 2 ^ 128) / + (2 * natSqrt x_hi_1)) + (natSqrt x_hi_1 * 2 ^ 128) + (((x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 / 2 ^ 128) / + (2 * natSqrt x_hi_1)) + -- (a*b)*(a*b) = (a*a)*(b*b) for the m*H*(m*H) term + have hp7 : natSqrt x_hi_1 * 2 ^ 128 * (natSqrt x_hi_1 * 2 ^ 128) = + natSqrt x_hi_1 * natSqrt x_hi_1 * (2 ^ 128 * 2 ^ 128) := by + calc natSqrt x_hi_1 * 2 ^ 128 * (natSqrt x_hi_1 * 2 ^ 128) + = natSqrt x_hi_1 * (2 ^ 128 * (natSqrt x_hi_1 * 2 ^ 128)) := Nat.mul_assoc _ _ _ + _ = natSqrt x_hi_1 * (natSqrt x_hi_1 * (2 ^ 128 * 2 ^ 128)) := by + congr 1; rw [← Nat.mul_assoc, Nat.mul_comm (2 ^ 128 : Nat) (natSqrt x_hi_1), + Nat.mul_assoc] + _ = natSqrt x_hi_1 * natSqrt x_hi_1 * (2 ^ 128 * 2 ^ 128) := (Nat.mul_assoc _ _ _).symm + -- m*H*q = m*q*H (re-association for the cross terms) + have hp8 : natSqrt x_hi_1 * 2 ^ 128 * + (((x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 / 2 ^ 128) / + (2 * natSqrt x_hi_1)) = + natSqrt x_hi_1 * + (((x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 / 2 ^ 128) / + (2 * natSqrt x_hi_1)) * 2 ^ 128 := by + rw [Nat.mul_assoc, Nat.mul_comm (2 ^ 128 : Nat), ← Nat.mul_assoc] + -- m*H*q = q*(m*H) (commutativity) + have hp9 := Nat.mul_comm (natSqrt x_hi_1 * 2 ^ 128) + (((x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 / 2 ^ 128) / + (2 * natSqrt x_hi_1)) + -- (d*q+rem)*H = d*q*H + rem*H + have hp10 := Nat.add_mul + (2 * natSqrt x_hi_1 * + (((x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 / 2 ^ 128) / + (2 * natSqrt x_hi_1))) + (((x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 / 2 ^ 128) % + (2 * natSqrt x_hi_1)) + (2 ^ 128 : Nat) + -- d*q = 2*m*q + have hp11 := Nat.mul_assoc (2 : Nat) (natSqrt x_hi_1) + (((x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 / 2 ^ 128) / + (2 * natSqrt x_hi_1)) + -- 2*m*q*H = 2*(m*q*H) + have hp12 := Nat.mul_assoc (2 : Nat) + (natSqrt x_hi_1 * + (((x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 / 2 ^ 128) / + (2 * natSqrt x_hi_1))) + (2 ^ 128 : Nat) + omega + -- Step 13: Apply correction_equiv + have hequiv := correction_equiv + (x_hi_1 * (2 ^ 128 * 2 ^ 128) + x_lo_1 / 2 ^ 128 * 2 ^ 128 + x_lo_1 % 2 ^ 128) + (((x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 / 2 ^ 128) / + (2 * natSqrt x_hi_1)) + (natSqrt x_hi_1 * 2 ^ 128 + + ((x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 / 2 ^ 128) / + (2 * natSqrt x_hi_1)) + (((x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 / 2 ^ 128) % + (2 * natSqrt x_hi_1) * 2 ^ 128) + (x_lo_1 % 2 ^ 128) + hident + -- Step 14: Close by case-splitting on the EVM comparison condition + by_cases hlt : ((x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 / 2 ^ 128) % + (2 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 % 2 ^ 128 < + (((x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 / 2 ^ 128) / + (2 * natSqrt x_hi_1)) * + (((x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 / 2 ^ 128) / + (2 * natSqrt x_hi_1)) + · -- EVM comparison true → correction = 1 + simp only [hlt, ↓reduceIte] + -- Derive: x < r*r (via hequiv + hx_full) + have hlt_x : x_hi_1 * 2 ^ 256 + x_lo_1 < + (natSqrt x_hi_1 * 2 ^ 128 + + ((x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 / 2 ^ 128) / + (2 * natSqrt x_hi_1)) * + (natSqrt x_hi_1 * 2 ^ 128 + + ((x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 / 2 ^ 128) / + (2 * natSqrt x_hi_1)) := by + rw [← hx_full]; exact hequiv.mpr hlt + -- Simplify the if in hcc to r-1 + simp only [hlt_x, ↓reduceIte] at hcc + exact hcc + · -- EVM comparison false → correction = 0 + simp only [hlt, ↓reduceIte, Nat.sub_zero] + -- Derive: ¬(x < r*r) + have hlt_x : ¬(x_hi_1 * 2 ^ 256 + x_lo_1 < + (natSqrt x_hi_1 * 2 ^ 128 + + ((x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 / 2 ^ 128) / + (2 * natSqrt x_hi_1)) * + (natSqrt x_hi_1 * 2 ^ 128 + + ((x_hi_1 - natSqrt x_hi_1 * natSqrt x_hi_1) * 2 ^ 128 + x_lo_1 / 2 ^ 128) / + (2 * natSqrt x_hi_1))) := + fun h => hlt (hequiv.mp (by rwa [hx_full])) + simp only [hlt_x, ↓reduceIte] at hcc + exact hcc + +/-- karatsubaFloor on normalized inputs fits in 256 bits. -/ +theorem karatsubaFloor_lt_word (x_hi_1 x_lo_1 : Nat) + (hlo : 2 ^ 254 ≤ x_hi_1) (hhi : x_hi_1 < 2 ^ 256) (hxlo : x_lo_1 < 2 ^ 256) : + karatsubaFloor x_hi_1 x_lo_1 < WORD_MOD := by + rw [karatsubaFloor_eq_natSqrt x_hi_1 x_lo_1 hlo hxlo, show WORD_MOD = 2 ^ 256 from rfl] + -- natSqrt(x) < 2^256 when x < 2^512 + suffices ¬(2 ^ 256 ≤ natSqrt (x_hi_1 * 2 ^ 256 + x_lo_1)) by omega + intro h + have hsq := natSqrt_sq_le (x_hi_1 * 2 ^ 256 + x_lo_1) + have := Nat.le_trans (Nat.mul_le_mul h h) hsq; omega + +end EvmBridge + +/-- The EVM model computes the same as the algebraic sqrt512. + With the refactored model, model_sqrt512_evm is just normalization + + 3 sub-model calls + un-normalization (~10 let-bindings), so this proof + chains sub-results through karatsubaFloor_eq_natSqrt and natSqrt_shift_div. -/ +theorem model_sqrt512_evm_eq_sqrt512 (x_hi x_lo : Nat) + (hxhi_pos : 0 < x_hi) (hxhi_lt : x_hi < 2 ^ 256) + (hxlo_lt : x_lo < 2 ^ 256) : + Sqrt512GeneratedModel.model_sqrt512_evm x_hi x_lo = + sqrt512 (x_hi * 2 ^ 256 + x_lo) := by + open Sqrt512GeneratedModel in + -- Step 0: sqrt512 takes else branch since x_hi > 0 → x ≥ 2^256 + have hx_ge : ¬(x_hi * 2 ^ 256 + x_lo < 2 ^ 256) := by omega + unfold sqrt512; simp only [hx_ge, ↓reduceIte] + -- Simplify: (x_hi*2^256+x_lo)/2^256 = x_hi + have hx_div : (x_hi * 2 ^ 256 + x_lo) / 2 ^ 256 = x_hi := by + rw [Nat.mul_comm, Nat.mul_add_div (Nat.two_pow_pos 256), + Nat.div_eq_of_lt hxlo_lt, Nat.add_zero] + rw [hx_div] + -- Now: LHS = model_sqrt512_evm x_hi x_lo + -- RHS = karatsubaFloor (x * 4^k / 2^256) (x * 4^k % 2^256) / 2^k + + -- Step 1: Get normalization results + have hnorm := evm_normalization_correct x_hi x_lo hxhi_pos hxhi_lt hxlo_lt + -- Rewrite RHS to use EVM expressions + rw [← hnorm.1, ← hnorm.2.1, ← hnorm.2.2.1] + -- Now RHS uses the same EVM sub-expressions as model_sqrt512_evm + -- Step 2: Unfold model_sqrt512_evm to expose sub-model calls + unfold model_sqrt512_evm; simp only [] + -- Step 3: Rewrite the composition of sub-models to karatsubaFloor + rw [evm_composition_eq_karatsubaFloor _ _ hnorm.2.2.2.1 hnorm.2.2.2.2.1 hnorm.2.2.2.2.2] + -- Step 4: Convert evmShr to division + have hshift_lt : evmShr (evmAnd (evmAnd 1 255) 255) (evmClz (u256 x_hi)) < 256 := by + rw [hnorm.2.2.1]; exact Nat.lt_of_le_of_lt (Nat.div_le_self _ _) (by omega) + rw [evmShr_eq' _ _ hshift_lt + (karatsubaFloor_lt_word _ _ hnorm.2.2.2.1 hnorm.2.2.2.2.1 hnorm.2.2.2.2.2)] + +set_option exponentiation.threshold 512 in +/-- The EVM model of 512-bit sqrt computes natSqrt. -/ +theorem model_sqrt512_evm_correct (x_hi x_lo : Nat) + (hxhi_pos : 0 < x_hi) (hxhi_lt : x_hi < 2 ^ 256) + (hxlo_lt : x_lo < 2 ^ 256) : + Sqrt512GeneratedModel.model_sqrt512_evm x_hi x_lo = + natSqrt (x_hi * 2 ^ 256 + x_lo) := by + rw [model_sqrt512_evm_eq_sqrt512 x_hi x_lo hxhi_pos hxhi_lt hxlo_lt] + have hx_lt : x_hi * 2 ^ 256 + x_lo < 2 ^ 512 := by + calc x_hi * 2 ^ 256 + x_lo + < 2 ^ 256 * 2 ^ 256 := by + have := Nat.mul_lt_mul_of_pos_right hxhi_lt (Nat.two_pow_pos 256) + omega + _ = 2 ^ 512 := by rw [← Nat.pow_add] + exact sqrt512_correct (x_hi * 2 ^ 256 + x_lo) hx_lt + +end Sqrt512Spec diff --git a/formal/sqrt/Sqrt512Proof/Sqrt512Proof/KaratsubaStep.lean b/formal/sqrt/Sqrt512Proof/Sqrt512Proof/KaratsubaStep.lean new file mode 100644 index 000000000..a636be5d7 --- /dev/null +++ b/formal/sqrt/Sqrt512Proof/Sqrt512Proof/KaratsubaStep.lean @@ -0,0 +1,255 @@ +/- + Karatsuba decomposition algebra for 512-bit square root. + + All theorems use explicit parameters (no `let` bindings in statements) + to avoid opaqueness issues with Lean 4's `intro` for `let`. +-/ +import SqrtProof.SqrtCorrect + +-- ============================================================================ +-- Helpers +-- ============================================================================ + +private theorem sq_expand (a b : Nat) : + (a + b) * (a + b) = a * a + 2 * a * b + b * b := by + rw [Nat.add_mul, Nat.mul_add, Nat.mul_add, Nat.mul_comm b a] + have : 2 * a * b = a * b + a * b := by rw [Nat.mul_assoc, Nat.two_mul] + omega + +private theorem mul_reassoc (a b : Nat) : a * b * (a * b) = a * a * (b * b) := by + rw [Nat.mul_assoc, Nat.mul_left_comm b a b, ← Nat.mul_assoc] + +private theorem succ_sq (m : Nat) : (m + 1) * (m + 1) = m * m + 2 * m + 1 := by + have := sq_expand m 1; simp [Nat.mul_one] at this; omega + +-- ============================================================================ +-- Part 1: Algebraic identity (explicit parameters, no let) +-- ============================================================================ + +theorem karatsuba_identity + (x_hi x_lo_hi x_lo_lo r_hi H : Nat) + (hres : r_hi * r_hi ≤ x_hi) : + x_hi * (H * H) + x_lo_hi * H + x_lo_lo + + ((x_hi - r_hi * r_hi) * H + x_lo_hi) / (2 * r_hi) * + (((x_hi - r_hi * r_hi) * H + x_lo_hi) / (2 * r_hi)) = + (r_hi * H + ((x_hi - r_hi * r_hi) * H + x_lo_hi) / (2 * r_hi)) * + (r_hi * H + ((x_hi - r_hi * r_hi) * H + x_lo_hi) / (2 * r_hi)) + + ((x_hi - r_hi * r_hi) * H + x_lo_hi) % (2 * r_hi) * H + x_lo_lo := by + -- Both sides equal r_hi^2*(H*H) + n*H + x_lo_lo + q^2 + -- where n = (x_hi - r_hi^2)*H + x_lo_hi + -- Euclidean division: (2*r_hi) * q + rem = n + have heuc := Nat.div_add_mod ((x_hi - r_hi * r_hi) * H + x_lo_hi) (2 * r_hi) + -- Expand square: (r_hi*H + q)^2 = r_hi*H*(r_hi*H) + 2*(r_hi*H)*q + q*q + have hexp := sq_expand (r_hi * H) (((x_hi - r_hi * r_hi) * H + x_lo_hi) / (2 * r_hi)) + -- r_hi*H*(r_hi*H) = r_hi*r_hi*(H*H) + have hreassoc := mul_reassoc r_hi H + -- Step 1: Expand square and reassociate on RHS + rw [hexp, hreassoc] + -- Step 2: Factor LHS: x_hi*(H*H) + x_lo_hi*H = r_hi*r_hi*(H*H) + n*H + have hfact1 : x_hi * (H * H) + x_lo_hi * H = + r_hi * r_hi * (H * H) + ((x_hi - r_hi * r_hi) * H + x_lo_hi) * H := by + -- Work from RHS to LHS + symm + rw [Nat.add_mul, Nat.mul_assoc (x_hi - r_hi * r_hi) H H, + ← Nat.add_assoc, ← Nat.add_mul] + congr 1; congr 1; omega + rw [hfact1] + -- Step 3: Show 2*(r_hi*H)*q + rem*H = n*H and substitute back + -- Helper: 2*(a*b)*c = 2*a*c*b (rearranges to factor out b) + have h_prod_comm : ∀ a b c : Nat, 2 * (a * b) * c = 2 * a * c * b := by + intro a b c + rw [(Nat.mul_assoc 2 a b).symm, Nat.mul_assoc (2 * a) b c, + Nat.mul_comm b c, (Nat.mul_assoc (2 * a) c b).symm] + have hfact2 : + 2 * (r_hi * H) * (((x_hi - r_hi * r_hi) * H + x_lo_hi) / (2 * r_hi)) + + ((x_hi - r_hi * r_hi) * H + x_lo_hi) % (2 * r_hi) * H = + ((x_hi - r_hi * r_hi) * H + x_lo_hi) * H := by + rw [h_prod_comm, ← Nat.add_mul]; congr 1 + rw [← hfact2] + -- Step 4: Both sides have the same atoms, just in different order + simp only [Nat.add_comm, Nat.add_left_comm] + +-- ============================================================================ +-- Part 2: Lower bound +-- ============================================================================ + +theorem rhi_H_le_natSqrt (x_hi x_lo r_hi H : Nat) + (hr_hi : r_hi * r_hi ≤ x_hi) : + r_hi * H ≤ natSqrt (x_hi * (H * H) + x_lo) := by + have hsq : (r_hi * H) * (r_hi * H) ≤ x_hi * (H * H) + x_lo := by + calc (r_hi * H) * (r_hi * H) + = r_hi * r_hi * (H * H) := mul_reassoc r_hi H + _ ≤ x_hi * (H * H) := Nat.mul_le_mul_right _ hr_hi + _ ≤ x_hi * (H * H) + x_lo := Nat.le_add_right _ _ + suffices h : ¬(natSqrt (x_hi * (H * H) + x_lo) < r_hi * H) by omega + intro h + have h1 : natSqrt (x_hi * (H * H) + x_lo) + 1 ≤ r_hi * H := h + have h2 := Nat.mul_le_mul h1 h1 + have h3 := natSqrt_lt_succ_sq (x_hi * (H * H) + x_lo) + omega + +-- ============================================================================ +-- Part 3: Combined bracket (specialized to 512-bit case) +-- ============================================================================ + +private theorem natSqrt_ge_pow127 (x_hi : Nat) (hlo : 2 ^ 254 ≤ x_hi) : + 2 ^ 127 ≤ natSqrt x_hi := by + suffices h : ¬(natSqrt x_hi < 2 ^ 127) by omega + intro h + have h1 : natSqrt x_hi + 1 ≤ 2 ^ 127 := h + have h2 := Nat.mul_le_mul h1 h1 + have h3 := natSqrt_lt_succ_sq x_hi + have h4 : (2 : Nat) ^ 127 * 2 ^ 127 = 2 ^ 254 := by rw [← Nat.pow_add] + omega + +set_option maxRecDepth 4096 in +/-- The Karatsuba bracket for the 512-bit case: natSqrt(x) ≤ r ≤ natSqrt(x) + 1. + Stated with fully expanded terms to avoid let-binding issues. -/ +theorem karatsuba_bracket_512 (x_hi x_lo_hi x_lo_lo : Nat) + (hxhi_lo : 2 ^ 254 ≤ x_hi) + (hxlo_hi : x_lo_hi < 2 ^ 128) (hxlo_lo : x_lo_lo < 2 ^ 128) : + let H : Nat := 2 ^ 128 + let r_hi := natSqrt x_hi + let q := ((x_hi - r_hi * r_hi) * H + x_lo_hi) / (2 * r_hi) + let r := r_hi * H + q + let x := x_hi * (H * H) + x_lo_hi * H + x_lo_lo + natSqrt x ≤ r ∧ r ≤ natSqrt x + 1 := by + intro H r_hi q r x + -- Key natSqrt bounds + have hs_lo := natSqrt_sq_le x -- s² ≤ x + have hs_hi := natSqrt_lt_succ_sq x -- x < (s+1)² + have hr_sq_le := natSqrt_sq_le x_hi -- r_hi² ≤ x_hi + have hr_sq_hi := natSqrt_lt_succ_sq x_hi -- x_hi < (r_hi+1)² + have hr_ge : 2 ^ 127 ≤ r_hi := natSqrt_ge_pow127 x_hi hxhi_lo + have hr_pos : 0 < r_hi := by omega + have hd_pos : 0 < 2 * r_hi := by omega + have hxlo_hi' : x_lo_hi < H := hxlo_hi + have hxlo_lo' : x_lo_lo < H := hxlo_lo + -- r_hi * H ≤ natSqrt(x) + have hlo : r_hi * H ≤ natSqrt x := by + show r_hi * H ≤ natSqrt (x_hi * (H * H) + x_lo_hi * H + x_lo_lo) + rw [Nat.add_assoc] + exact rhi_H_le_natSqrt x_hi (x_lo_hi * H + x_lo_lo) r_hi H hr_sq_le + -- x_lo_hi*H + x_lo_lo < H*H + have hxlo_lt : x_lo_hi * H + x_lo_lo < H * H := by + have := Nat.mul_le_mul_right H (show x_lo_hi + 1 ≤ H from hxlo_hi') + rw [Nat.add_mul, Nat.one_mul] at this; omega + -- natSqrt(x) < (r_hi + 1) * H + have hhi : natSqrt x < (r_hi + 1) * H := by + suffices hx_lt : x < (r_hi + 1) * H * ((r_hi + 1) * H) by + suffices h : ¬((r_hi + 1) * H ≤ natSqrt x) by omega + intro hc; have h2 := Nat.mul_le_mul hc hc; omega + show x < (r_hi + 1) * H * ((r_hi + 1) * H) + rw [mul_reassoc] + have hr_sq_hi' : x_hi < (r_hi + 1) * (r_hi + 1) := hr_sq_hi + calc x + < x_hi * (H * H) + H * H := by omega + _ = (x_hi + 1) * (H * H) := by rw [Nat.add_mul, Nat.one_mul] + _ ≤ (r_hi + 1) * (r_hi + 1) * (H * H) := + Nat.mul_le_mul_right _ (by omega) + -- Key helpers for the bracket proof + have hrhH := mul_reassoc r_hi H + have hs_eq : natSqrt x = r_hi * H + (natSqrt x - r_hi * H) := + (Nat.add_sub_cancel' hlo).symm + have hsx := sq_expand (r_hi * H) (natSqrt x - r_hi * H) + -- s² = r_hi²*H² + 2*(r_hi*H)*e + e² ≤ x + have h_sq_le : r_hi * r_hi * (H * H) + 2 * (r_hi * H) * (natSqrt x - r_hi * H) + + (natSqrt x - r_hi * H) * (natSqrt x - r_hi * H) ≤ x := by + rw [← hrhH, ← hsx, ← hs_eq]; exact hs_lo + -- (s+1)² = r_hi²*H² + 2*(r_hi*H)*(e+1) + (e+1)² > x + have hsx1 := sq_expand (r_hi * H) (natSqrt x - r_hi * H + 1) + have h_sq_hi : x < r_hi * r_hi * (H * H) + + 2 * (r_hi * H) * (natSqrt x - r_hi * H + 1) + + (natSqrt x - r_hi * H + 1) * (natSqrt x - r_hi * H + 1) := by + have h_s1 : natSqrt x + 1 = r_hi * H + (natSqrt x - r_hi * H + 1) := by + rw [← Nat.add_assoc]; congr 1 + rw [← hrhH, ← hsx1, ← h_s1]; exact hs_hi + -- Product rearrangement: 2*(r_hi*H)*e = 2*r_hi*e*H + have h2rhe : ∀ e : Nat, 2 * (r_hi * H) * e = 2 * r_hi * e * H := by + intro e; rw [(Nat.mul_assoc 2 r_hi H).symm, Nat.mul_assoc (2 * r_hi) H e, + Nat.mul_comm H e, (Nat.mul_assoc (2 * r_hi) e H).symm] + -- Algebraic identity: x_hi*H² + x_lo_hi*H = r_hi²*H² + n*H + -- (proved as standalone lemma to avoid 2^128 evaluation during rw) + have hfact_key : x_hi * (H * H) + x_lo_hi * H = + r_hi * r_hi * (H * H) + ((x_hi - r_hi * r_hi) * H + x_lo_hi) * H := by + have h := Nat.add_sub_cancel' hr_sq_le + -- h : r_hi * r_hi + (x_hi - r_hi * r_hi) = x_hi + -- RHS = (r_hi² + (x_hi-r_hi²)) * H² + x_lo_hi*H = x_hi*H² + x_lo_hi*H = LHS + have : x_hi * (H * H) = (r_hi * r_hi + (x_hi - r_hi * r_hi)) * (H * H) := by + rw [h] + -- this : x_hi * (H * H) = (r_hi * r_hi + (x_hi - r_hi * r_hi)) * (H * H) + rw [this, Nat.add_mul, Nat.add_assoc] + have h3 : ((x_hi - r_hi * r_hi) * H + x_lo_hi) * H = + (x_hi - r_hi * r_hi) * (H * H) + x_lo_hi * H := by + rw [Nat.add_mul, Nat.mul_assoc] + rw [h3] + have h_decomp : x = r_hi * r_hi * (H * H) + + ((x_hi - r_hi * r_hi) * H + x_lo_hi) * H + x_lo_lo := by + show x_hi * (H * H) + x_lo_hi * H + x_lo_lo = + r_hi * r_hi * (H * H) + ((x_hi - r_hi * r_hi) * H + x_lo_hi) * H + x_lo_lo + rw [hfact_key] + -- Lower bound: e ≤ q + have h_e_le_q : natSqrt x - r_hi * H ≤ q := by + rw [Nat.le_div_iff_mul_le hd_pos] + -- Strategy: show e*d*H < (n+1)*H, then divide by H + suffices h_mul : (natSqrt x - r_hi * H) * (2 * r_hi) * H < + ((x_hi - r_hi * r_hi) * H + x_lo_hi + 1) * H by + exact Nat.le_of_lt_succ (Nat.lt_of_mul_lt_mul_right h_mul) + -- Rearrange LHS: e*d*H = 2*(r_hi*H)*e + have h_rearr : (natSqrt x - r_hi * H) * (2 * r_hi) * H = + 2 * (r_hi * H) * (natSqrt x - r_hi * H) := by + rw [Nat.mul_comm (natSqrt x - r_hi * H) (2 * r_hi)] + exact (h2rhe (natSqrt x - r_hi * H)).symm + -- 2*(r_hi*H)*e + r_hi²*H² ≤ x (from h_sq_le, dropping e²) + have h_bound : 2 * (r_hi * H) * (natSqrt x - r_hi * H) + + r_hi * r_hi * (H * H) ≤ x := + calc 2 * (r_hi * H) * (natSqrt x - r_hi * H) + r_hi * r_hi * (H * H) + = r_hi * r_hi * (H * H) + 2 * (r_hi * H) * (natSqrt x - r_hi * H) := + Nat.add_comm _ _ + _ ≤ r_hi * r_hi * (H * H) + 2 * (r_hi * H) * (natSqrt x - r_hi * H) + + (natSqrt x - r_hi * H) * (natSqrt x - r_hi * H) := + Nat.le_add_right _ _ + _ ≤ x := h_sq_le + -- e*d*H = 2*(r_hi*H)*e < n*H + H = (n+1)*H + rw [h_rearr, Nat.add_mul, Nat.one_mul] + -- h_bound + h_decomp + hxlo_lo' close the goal + omega + -- Upper bound: q ≤ e + 1 + have h_q_le_e1 : q ≤ natSqrt x - r_hi * H + 1 := by + show ((x_hi - r_hi * r_hi) * H + x_lo_hi) / (2 * r_hi) ≤ + natSqrt x - r_hi * H + 1 + -- Strategy: show n*H < (e+2)*(2*r_hi)*H, then divide by H + suffices h_mul : ((x_hi - r_hi * r_hi) * H + x_lo_hi) * H < + (natSqrt x - r_hi * H + 1 + 1) * (2 * r_hi) * H by + have h2 := Nat.lt_of_mul_lt_mul_right h_mul + -- h2 : n < (e+2) * d, so q = n/d < e+2, so q ≤ e+1 + have h3 := (Nat.div_lt_iff_lt_mul hd_pos).mpr h2 + omega + -- (e+1)² ≤ 2*r_hi*H since e+1 ≤ H ≤ 2*r_hi + have he_lt_H : natSqrt x - r_hi * H + 1 ≤ H := by omega + have hH_le_2r : H ≤ 2 * r_hi := by omega + have h_sq_bound : (natSqrt x - r_hi * H + 1) * (natSqrt x - r_hi * H + 1) ≤ + 2 * r_hi * H := + calc (natSqrt x - r_hi * H + 1) * (natSqrt x - r_hi * H + 1) + ≤ (natSqrt x - r_hi * H + 1) * H := Nat.mul_le_mul_left _ he_lt_H + _ ≤ H * H := Nat.mul_le_mul_right _ he_lt_H + _ ≤ 2 * r_hi * H := Nat.mul_le_mul_right _ hH_le_2r + -- Rearrange RHS: (e+2)*(2*r_hi)*H = 2*(r_hi*H)*(e+1) + 2*r_hi*H + have h_rhs : (natSqrt x - r_hi * H + 1 + 1) * (2 * r_hi) * H = + 2 * (r_hi * H) * (natSqrt x - r_hi * H + 1) + 2 * r_hi * H := by + have : (natSqrt x - r_hi * H + 1 + 1) * (2 * r_hi) = + 2 * r_hi * (natSqrt x - r_hi * H + 1) + 2 * r_hi := by + rw [show natSqrt x - r_hi * H + 1 + 1 = (natSqrt x - r_hi * H + 1) + 1 from rfl, + Nat.add_mul, Nat.one_mul, Nat.mul_comm] + rw [this, Nat.add_mul, (h2rhe (natSqrt x - r_hi * H + 1)).symm] + rw [h_rhs] + -- Goal: n*H < 2*(r_hi*H)*(e+1) + 2*r_hi*H + -- From h_sq_hi and h_decomp: n*H + x_lo_lo < 2*(r_hi*H)*(e+1) + (e+1)² + -- And (e+1)² ≤ 2*r_hi*H (h_sq_bound) + -- So n*H < 2*(r_hi*H)*(e+1) + 2*r_hi*H + omega + constructor + · -- natSqrt x ≤ r = r_hi * H + q + show natSqrt x ≤ r_hi * H + q; omega + · -- r = r_hi * H + q ≤ natSqrt x + 1 + show r_hi * H + q ≤ natSqrt x + 1; omega diff --git a/formal/sqrt/Sqrt512Proof/Sqrt512Proof/Normalization.lean b/formal/sqrt/Sqrt512Proof/Sqrt512Proof/Normalization.lean new file mode 100644 index 000000000..583c3b95a --- /dev/null +++ b/formal/sqrt/Sqrt512Proof/Sqrt512Proof/Normalization.lean @@ -0,0 +1,129 @@ +/- + Normalization lemma for 512-bit square root. + + Main theorem: natSqrt(x * 4^k) / 2^k = natSqrt(x) + + Justifies the even-shift normalize / un-normalize pattern in 512Math._sqrt. +-/ +import SqrtProof.SqrtCorrect + +-- ============================================================================ +-- Part 1: Uniqueness of floor sqrt +-- ============================================================================ + +/-- If m^2 <= n < (m+1)^2, then natSqrt n = m. -/ +theorem natSqrt_unique (n m : Nat) + (hlo : m * m ≤ n) (hhi : n < (m + 1) * (m + 1)) : + natSqrt n = m := by + have ⟨hrlo, hrhi⟩ := natSqrt_spec n + have hmr : m ≤ natSqrt n := by + suffices h : ¬(natSqrt n < m) by omega + intro h + have h1 : natSqrt n + 1 ≤ m := h + have h2 := Nat.mul_le_mul h1 h1 + omega + have hrm : natSqrt n ≤ m := by + suffices h : ¬(m < natSqrt n) by omega + intro h + have h1 : m + 1 ≤ natSqrt n := h + have h2 := Nat.mul_le_mul h1 h1 + omega + omega + +-- ============================================================================ +-- Part 2: Bracket for natSqrt of scaled value +-- ============================================================================ + +private theorem mul_sq (a b : Nat) : (a * b) * (a * b) = (a * a) * (b * b) := by + calc (a * b) * (a * b) + = a * (b * (a * b)) := by rw [Nat.mul_assoc] + _ = a * (a * (b * b)) := by rw [Nat.mul_left_comm b a b] + _ = (a * a) * (b * b) := by rw [Nat.mul_assoc] + +theorem natSqrt_mul_sq_lower (x c : Nat) : + natSqrt x * c ≤ natSqrt (x * (c * c)) := by + by_cases hc : c = 0 + · simp [hc] + · have hsq : (natSqrt x * c) * (natSqrt x * c) ≤ x * (c * c) := by + rw [mul_sq]; exact Nat.mul_le_mul_right _ (natSqrt_sq_le x) + suffices h : ¬(natSqrt (x * (c * c)) < natSqrt x * c) by omega + intro h + have h1 : natSqrt (x * (c * c)) + 1 ≤ natSqrt x * c := h + have h2 := Nat.mul_le_mul h1 h1 + have h3 := natSqrt_lt_succ_sq (x * (c * c)) + omega + +theorem natSqrt_mul_sq_upper (x c : Nat) (hc : 0 < c) : + natSqrt (x * (c * c)) < (natSqrt x + 1) * c := by + have hsq : x * (c * c) < ((natSqrt x + 1) * c) * ((natSqrt x + 1) * c) := by + rw [mul_sq] + exact Nat.mul_lt_mul_of_pos_right (natSqrt_lt_succ_sq x) (Nat.mul_pos hc hc) + suffices h : ¬((natSqrt x + 1) * c ≤ natSqrt (x * (c * c))) by omega + intro h + have h2 := Nat.mul_le_mul h h + have h3 := natSqrt_sq_le (x * (c * c)) + omega + +-- ============================================================================ +-- Part 3: Division theorem +-- ============================================================================ + +private theorem four_pow_eq (k : Nat) : 4 ^ k = 2 ^ k * 2 ^ k := by + have : (4 : Nat) = 2 ^ 2 := by decide + rw [this, ← Nat.pow_mul, ← Nat.pow_add] + congr 1; omega + +/-- natSqrt(x * 4^k) / 2^k = natSqrt(x). -/ +theorem natSqrt_shift_div (x k : Nat) : + natSqrt (x * 4 ^ k) / 2 ^ k = natSqrt x := by + by_cases hk : k = 0 + · simp [hk] + · have hpow : 0 < 2 ^ k := Nat.two_pow_pos k + rw [four_pow_eq] + have hlo := natSqrt_mul_sq_lower x (2 ^ k) + have hhi := natSqrt_mul_sq_upper x (2 ^ k) hpow + have h1 : natSqrt x ≤ natSqrt (x * (2 ^ k * 2 ^ k)) / 2 ^ k := by + rw [Nat.le_div_iff_mul_le hpow] + exact hlo + have h2 : natSqrt (x * (2 ^ k * 2 ^ k)) / 2 ^ k < natSqrt x + 1 := by + rw [Nat.div_lt_iff_lt_mul hpow] + -- Need: natSqrt(x * (2^k * 2^k)) < (natSqrt x + 1) * 2^k + -- hhi says exactly this + exact hhi + omega + +-- ============================================================================ +-- Part 4: Shift-range lemma +-- ============================================================================ + +private theorem four_pow_eq_two_pow (shift : Nat) : 4 ^ shift = 2 ^ (2 * shift) := by + have : (4 : Nat) = 2 ^ 2 := by decide + rw [this, ← Nat.pow_mul] + +/-- After normalization, x_hi * 4^shift in [2^254, 2^256). -/ +theorem shift_range (x_hi : Nat) (hlo : 0 < x_hi) (hhi : x_hi < 2 ^ 256) : + let shift := (255 - Nat.log2 x_hi) / 2 + 2 ^ 254 ≤ x_hi * 4 ^ shift ∧ x_hi * 4 ^ shift < 2 ^ 256 := by + intro shift + have hne : x_hi ≠ 0 := Nat.ne_of_gt hlo + have hlog := (Nat.log2_eq_iff hne).1 rfl + have hL : Nat.log2 x_hi ≤ 255 := by + have := (Nat.log2_lt hne).2 hhi; omega + have h2shift : 2 * shift ≤ 255 - Nat.log2 x_hi := Nat.mul_div_le (255 - Nat.log2 x_hi) 2 + have h2shift_lb : 255 - Nat.log2 x_hi < 2 * shift + 2 := by + have h := Nat.div_add_mod (255 - Nat.log2 x_hi) 2 + have hmod : (255 - Nat.log2 x_hi) % 2 < 2 := Nat.mod_lt _ (by omega) + omega + rw [four_pow_eq_two_pow] + constructor + · calc 2 ^ 254 + ≤ 2 ^ (Nat.log2 x_hi + 2 * shift) := by + apply Nat.pow_le_pow_right (by omega : 1 ≤ 2); omega + _ = 2 ^ (Nat.log2 x_hi) * 2 ^ (2 * shift) := by rw [Nat.pow_add] + _ ≤ x_hi * 2 ^ (2 * shift) := Nat.mul_le_mul_right _ hlog.1 + · calc x_hi * 2 ^ (2 * shift) + < 2 ^ (Nat.log2 x_hi + 1) * 2 ^ (2 * shift) := + Nat.mul_lt_mul_of_pos_right hlog.2 (Nat.two_pow_pos _) + _ = 2 ^ (Nat.log2 x_hi + 1 + 2 * shift) := by + rw [← Nat.pow_add] + _ ≤ 2 ^ 256 := Nat.pow_le_pow_right (by omega : 1 ≤ 2) (by omega) diff --git a/formal/sqrt/Sqrt512Proof/Sqrt512Proof/OsqrtUpSpec.lean b/formal/sqrt/Sqrt512Proof/Sqrt512Proof/OsqrtUpSpec.lean new file mode 100644 index 000000000..8bec1569f --- /dev/null +++ b/formal/sqrt/Sqrt512Proof/Sqrt512Proof/OsqrtUpSpec.lean @@ -0,0 +1,436 @@ +/- + Bridge proof: model_osqrtUp_evm computes sqrtUp512. + + The auto-generated model_osqrtUp_evm returns (r_hi, r_lo) where + r_hi * 2^256 + r_lo = sqrtUp512(x_hi * 2^256 + x_lo). + + Case x_hi = 0: r_hi = 0, r_lo = inlined 256-bit sqrtUp(x_lo). + Case x_hi > 0: r = floor_sqrt(x), needsUp = (x > r²), result = r + needsUp with carry. +-/ +import Sqrt512Proof.GeneratedSqrt512Model +import Sqrt512Proof.GeneratedSqrt512Spec +import Sqrt512Proof.SqrtUpCorrect +import Sqrt512Proof.SqrtWrapperSpec +import SqrtProof.GeneratedSqrtModel +import SqrtProof.GeneratedSqrtSpec +import SqrtProof.SqrtCorrect + +namespace Sqrt512Spec + +open Sqrt512GeneratedModel + +-- ============================================================================ +-- Section 1: x_hi = 0 branch — bridge to model_sqrt_up_evm +-- ============================================================================ + +/-- When x_hi = 0, the first component (r_hi) is 0. -/ +private theorem osqrtUp_zero_fst (x_lo : Nat) : + (model_osqrtUp_evm 0 x_lo).1 = 0 := by + simp only [model_osqrtUp_evm] + simp only [evmEq_compat, u256_compat, su256_zero] + simp only [SqrtGeneratedModel.evmEq, SqrtGeneratedModel.u256, SqrtGeneratedModel.WORD_MOD] + simp (config := { decide := true }) + +/-- When x_hi = 0, the second component (r_lo) equals model_sqrt_up_evm x_lo. -/ +private theorem osqrtUp_zero_snd (x_lo : Nat) : + (model_osqrtUp_evm 0 x_lo).2 = SqrtGeneratedModel.model_sqrt_up_evm x_lo := by + simp only [model_osqrtUp_evm, model_sqrt256_up_evm, + SqrtGeneratedModel.model_sqrt_up_evm, SqrtGeneratedModel.model_sqrt_evm] + simp only [evmEq_compat, evmShr_compat, evmAdd_compat, evmDiv_compat, + evmSub_compat, evmClz_compat, evmShl_compat, evmLt_compat, + evmMul_compat, evmGt_compat, u256_compat] + simp only [su256_zero, su256_idem] + simp only [SqrtGeneratedModel.evmEq, SqrtGeneratedModel.u256, SqrtGeneratedModel.WORD_MOD] + simp (config := { decide := true }) + +/-- Ceiling sqrt uniqueness: if x ≤ r² and r is minimal, then r = sqrtUp512 x. -/ +private theorem sqrtUp512_unique (x r : Nat) (hx : x < 2 ^ 512) + (hle : x ≤ r * r) (hmin : ∀ y, x ≤ y * y → r ≤ y) : + r = sqrtUp512 x := by + have ⟨hup_le, hup_min⟩ := sqrtUp512_correct x hx + have h1 := hmin (sqrtUp512 x) hup_le + have h2 := hup_min r hle + omega + +-- ============================================================================ +-- Section 2: Helper lemmas for x_hi > 0 — _mul correctness +-- ============================================================================ + +/-- mulmod(r, r, 2^256-1) combined with mul(r,r) and sub/lt recovers r²/2^256. + Key identity: 2^256 ≡ 1 (mod 2^256-1). -/ +private theorem mul512_high_word (r : Nat) (hr : r < WORD_MOD) : + let mm := evmMulmod r r (evmNot 0) + let m := evmMul r r + evmSub (evmSub mm m) (evmLt mm m) = r * r / WORD_MOD := by + simp only + -- Step 1: Simplify evmNot 0 + have hNot0 : evmNot 0 = WORD_MOD - 1 := by + unfold evmNot u256 WORD_MOD; simp + -- Step 2: Simplify evmMulmod and evmMul + have hWM1_pos : (0 : Nat) < WORD_MOD - 1 := by unfold WORD_MOD; omega + have hWM1_lt : WORD_MOD - 1 < WORD_MOD := by unfold WORD_MOD; omega + have hmm : evmMulmod r r (evmNot 0) = (r * r) % (WORD_MOD - 1) := by + unfold evmMulmod + simp only [u256_id' r hr, hNot0, u256_id' (WORD_MOD - 1) hWM1_lt] + simp [Nat.ne_of_gt hWM1_pos] + have hm : evmMul r r = (r * r) % WORD_MOD := by + unfold evmMul u256; simp [Nat.mod_eq_of_lt hr] + rw [hmm, hm] + -- Abbreviate for readability (without set tactic) + -- hi = (r*r) % (WORD_MOD - 1), lo = (r*r) % WORD_MOD, q = r*r / WORD_MOD + have hdecomp : r * r = r * r / WORD_MOD * WORD_MOD + r * r % WORD_MOD := by + have := Nat.div_add_mod (r * r) WORD_MOD + rw [Nat.mul_comm] at this; omega + have hq_bound : r * r / WORD_MOD < WORD_MOD := by + have : r * r < WORD_MOD * WORD_MOD := + Nat.mul_lt_mul_of_le_of_lt (Nat.le_of_lt hr) hr (by unfold WORD_MOD; omega) + exact Nat.div_lt_of_lt_mul this + have hlo_bound : r * r % WORD_MOD < WORD_MOD := Nat.mod_lt _ (by unfold WORD_MOD; omega) + -- Key congruence: hi = (q + lo) % (WORD_MOD - 1) + -- Since WORD_MOD = (WORD_MOD - 1) + 1, q*WORD_MOD = q*(WORD_MOD-1) + q + -- So r*r = q*(WORD_MOD-1) + (q + lo), and r*r % (WORD_MOD-1) = (q+lo) % (WORD_MOD-1) + have hhi_eq : (r * r) % (WORD_MOD - 1) = (r * r / WORD_MOD + r * r % WORD_MOD) % (WORD_MOD - 1) := by + -- q * W = (W-1)*q + q + have hqW : r * r / WORD_MOD * WORD_MOD = + (WORD_MOD - 1) * (r * r / WORD_MOD) + r * r / WORD_MOD := by + have hsc := Nat.sub_add_cancel (Nat.one_le_of_lt (show 1 < WORD_MOD from by unfold WORD_MOD; omega)) + -- q * ((W-1) + 1) = q*(W-1) + q*1 + have h := Nat.mul_add (r * r / WORD_MOD) (WORD_MOD - 1) 1 + rw [hsc, Nat.mul_one] at h + -- h : r * r / WORD_MOD * WORD_MOD = r * r / WORD_MOD * (WORD_MOD - 1) + r * r / WORD_MOD + rw [h, Nat.mul_comm (r * r / WORD_MOD) (WORD_MOD - 1)] + -- r*r = (W-1)*q + (q+lo) + have hrr_eq : r * r = (WORD_MOD - 1) * (r * r / WORD_MOD) + (r * r / WORD_MOD + r * r % WORD_MOD) := by + omega + -- Apply Nat.mul_add_mod: ((W-1)*q + (q+lo)) % (W-1) = (q+lo) % (W-1) + have step := Nat.mul_add_mod (WORD_MOD - 1) (r * r / WORD_MOD) (r * r / WORD_MOD + r * r % WORD_MOD) + -- step : ((WORD_MOD - 1) * (r * r / WORD_MOD) + (r * r / WORD_MOD + r * r % WORD_MOD)) % (WORD_MOD - 1) = + -- (r * r / WORD_MOD + r * r % WORD_MOD) % (WORD_MOD - 1) + rw [← hrr_eq] at step; exact step + have hhi_bound : (r * r) % (WORD_MOD - 1) < WORD_MOD - 1 := Nat.mod_lt _ hWM1_pos + -- Case split on whether q + lo wraps modulo (WORD_MOD - 1) + by_cases hcase : r * r / WORD_MOD + r * r % WORD_MOD < WORD_MOD - 1 + · -- Case 1: no wrap + have hhi_val : (r * r) % (WORD_MOD - 1) = r * r / WORD_MOD + r * r % WORD_MOD := by + rw [hhi_eq, Nat.mod_eq_of_lt hcase] + have hhi_wm : (r * r) % (WORD_MOD - 1) < WORD_MOD := by omega + have hge : r * r % WORD_MOD ≤ (r * r) % (WORD_MOD - 1) := by + rw [hhi_val]; exact Nat.le_add_left _ _ + have hlt_eq : evmLt ((r * r) % (WORD_MOD - 1)) (r * r % WORD_MOD) = 0 := by + unfold evmLt u256 + simp only [Nat.mod_eq_of_lt hhi_wm, Nat.mod_eq_of_lt hlo_bound] + exact if_neg (Nat.not_lt.mpr hge) + rw [hlt_eq] + have hsub1 : evmSub ((r * r) % (WORD_MOD - 1)) (r * r % WORD_MOD) = + (r * r) % (WORD_MOD - 1) - r * r % WORD_MOD := + evmSub_eq_of_le _ _ hhi_wm hge + rw [hsub1] + have hq_eq : (r * r) % (WORD_MOD - 1) - r * r % WORD_MOD = r * r / WORD_MOD := by + omega + rw [hq_eq] + -- evmSub q 0 = q + exact evmSub_eq_of_le _ 0 hq_bound (Nat.zero_le _) + · -- Case 2: wrap (hcase : ¬(q + lo < W-1), i.e., W-1 ≤ q + lo) + have hcase' : WORD_MOD - 1 ≤ r * r / WORD_MOD + r * r % WORD_MOD := Nat.not_lt.mp hcase + -- r * r ≤ (WORD_MOD-1)^2 since r < WORD_MOD + -- q + lo < 2*(WORD_MOD-1) because q ≤ WORD_MOD-2 and lo ≤ WORD_MOD-1 + -- r ≤ WORD_MOD - 1, so r*r ≤ (WORD_MOD-1)^2, so q = r*r/WORD_MOD ≤ WORD_MOD - 2 + have hq_le : r * r / WORD_MOD ≤ WORD_MOD - 2 := by + have hr' : r ≤ WORD_MOD - 1 := by omega + have hrsq : r * r ≤ (WORD_MOD - 1) * (WORD_MOD - 1) := Nat.mul_le_mul hr' hr' + have h1 : r * r / WORD_MOD ≤ (WORD_MOD - 1) * (WORD_MOD - 1) / WORD_MOD := + @Nat.div_le_div_right _ _ WORD_MOD hrsq + suffices h : (WORD_MOD - 1) * (WORD_MOD - 1) / WORD_MOD = WORD_MOD - 2 by omega + unfold WORD_MOD; omega + have hql_lt : r * r / WORD_MOD + r * r % WORD_MOD < 2 * (WORD_MOD - 1) := by omega + have hhi_val : (r * r) % (WORD_MOD - 1) = + r * r / WORD_MOD + r * r % WORD_MOD - (WORD_MOD - 1) := by + rw [hhi_eq, + Nat.mod_eq_sub_mod hcase', + Nat.mod_eq_of_lt (by omega)] + have hlt_lo : (r * r) % (WORD_MOD - 1) < r * r % WORD_MOD := by + rw [hhi_val]; omega + have hhi_wm : (r * r) % (WORD_MOD - 1) < WORD_MOD := by omega + have hlt_eq : evmLt ((r * r) % (WORD_MOD - 1)) (r * r % WORD_MOD) = 1 := by + unfold evmLt u256 + simp [Nat.mod_eq_of_lt hhi_wm, Nat.mod_eq_of_lt hlo_bound] + exact hlt_lo + rw [hlt_eq] + -- evmSub wraps: hi + WORD_MOD - lo + have hsub1 : evmSub ((r * r) % (WORD_MOD - 1)) (r * r % WORD_MOD) = + (r * r) % (WORD_MOD - 1) + WORD_MOD - r * r % WORD_MOD := by + unfold evmSub u256 + simp [Nat.mod_eq_of_lt hhi_wm, Nat.mod_eq_of_lt hlo_bound] + exact Nat.mod_eq_of_lt (show (r * r) % (WORD_MOD - 1) + WORD_MOD - r * r % WORD_MOD < WORD_MOD + by rw [hhi_val]; omega) + rw [hsub1] + have hval : (r * r) % (WORD_MOD - 1) + WORD_MOD - r * r % WORD_MOD < WORD_MOD := by + rw [hhi_val]; omega + have hsub2 : evmSub ((r * r) % (WORD_MOD - 1) + WORD_MOD - r * r % WORD_MOD) 1 = + (r * r) % (WORD_MOD - 1) + WORD_MOD - r * r % WORD_MOD - 1 := + evmSub_eq_of_le _ 1 hval (by rw [hhi_val]; omega) + rw [hsub2] + rw [hhi_val]; omega + +/-- mul(r, r) gives the low word of r². -/ +private theorem mul512_low_word (r : Nat) (hr : r < WORD_MOD) : + evmMul r r = r * r % WORD_MOD := by + unfold evmMul u256; simp [Nat.mod_eq_of_lt hr] + +-- ============================================================================ +-- Section 3: Helper lemmas for x_hi > 0 — _gt correctness +-- ============================================================================ + +/-- The 512-bit lexicographic comparison correctly computes x > r². -/ +private theorem gt512_correct (x_hi x_lo sq_hi sq_lo : Nat) + (hxhi : x_hi < WORD_MOD) (hxlo : x_lo < WORD_MOD) + (hsqhi : sq_hi < WORD_MOD) (hsqlo : sq_lo < WORD_MOD) : + let cmp := evmOr (evmGt x_hi sq_hi) + (evmAnd (evmEq x_hi sq_hi) (evmGt x_lo sq_lo)) + (cmp ≠ 0) ↔ (x_hi * WORD_MOD + x_lo > sq_hi * WORD_MOD + sq_lo) := by + simp only + -- Simplify EVM operations to pure comparisons + have hgt_hi : evmGt x_hi sq_hi = if x_hi > sq_hi then 1 else 0 := by + unfold evmGt u256; simp [Nat.mod_eq_of_lt hxhi, Nat.mod_eq_of_lt hsqhi] + have heq_hi : evmEq x_hi sq_hi = if x_hi = sq_hi then 1 else 0 := by + unfold evmEq u256; simp [Nat.mod_eq_of_lt hxhi, Nat.mod_eq_of_lt hsqhi] + have hgt_lo : evmGt x_lo sq_lo = if x_lo > sq_lo then 1 else 0 := by + unfold evmGt u256; simp [Nat.mod_eq_of_lt hxlo, Nat.mod_eq_of_lt hsqlo] + rw [hgt_hi, heq_hi, hgt_lo] + -- Full case analysis on orderings + by_cases hgt : x_hi > sq_hi + · -- x_hi > sq_hi: LHS or has at least one 1 + have hneq : ¬(x_hi = sq_hi) := by omega + simp only [hgt, ite_true, hneq, ite_false] + -- evmOr 1 (evmAnd 0 (if ...)) always reduces to something nonzero + have hor_nz : ∀ v, evmOr 1 (evmAnd 0 v) ≠ 0 := by + intro v; unfold evmOr evmAnd u256 WORD_MOD; simp (config := { decide := true }) + constructor + · intro _ + -- sq_hi + 1 ≤ x_hi, so sq_hi*W + W ≤ x_hi*W + have h1 : sq_hi * WORD_MOD + WORD_MOD ≤ x_hi * WORD_MOD := by + have := Nat.mul_le_mul_right WORD_MOD hgt + rwa [Nat.succ_mul] at this + omega + · intro _; exact hor_nz _ + · by_cases heq : x_hi = sq_hi + · subst heq + simp only [Nat.lt_irrefl, ite_false, ite_true] + by_cases hgtlo : x_lo > sq_lo + · simp only [hgtlo, ite_true] + constructor + · intro _; omega + · intro _; unfold evmOr evmAnd u256 WORD_MOD; simp (config := { decide := true }) + · simp only [hgtlo, ite_false] + have hor_z : evmOr 0 (evmAnd 1 0) = 0 := by + unfold evmOr evmAnd u256 WORD_MOD; simp (config := { decide := true }) + constructor + · intro h; exact absurd hor_z h + · intro h; omega + · -- x_hi < sq_hi + have hlt : x_hi < sq_hi := by omega + have hng : ¬(x_hi > sq_hi) := by omega + simp only [hng, ite_false, heq, ite_false] + -- evmOr 0 (evmAnd 0 (if ...)) = 0 + have hor_z : ∀ v, evmOr 0 (evmAnd 0 v) = 0 := by + intro v; unfold evmOr evmAnd u256 WORD_MOD; simp (config := { decide := true }) + constructor + · intro h; exact absurd (hor_z _) h + · intro h + have h1 : x_hi * WORD_MOD + WORD_MOD ≤ sq_hi * WORD_MOD := by + have := Nat.mul_le_mul_right WORD_MOD hlt + rwa [Nat.succ_mul] at this + omega + +-- ============================================================================ +-- Section 4: Helper lemmas for x_hi > 0 — _add correctness +-- ============================================================================ + +/-- add(r, needsUp) with carry detection gives correct 512-bit result. + When needsUp ∈ {0,1}, the result r + needsUp is at most 2^256. -/ +private theorem add_with_carry (r needsUp : Nat) (hr : r < WORD_MOD) + (hn : needsUp = 0 ∨ needsUp = 1) : + let r_lo := evmAdd r needsUp + let r_hi := evmLt (evmAdd r needsUp) r + r_hi * WORD_MOD + r_lo = r + needsUp := by + simp only + have hn_bound : needsUp < WORD_MOD := by rcases hn with h | h <;> (rw [h]; unfold WORD_MOD; omega) + by_cases hov : r + needsUp < WORD_MOD + · -- No overflow + have hadd : evmAdd r needsUp = r + needsUp := + evmAdd_eq' r needsUp hr hn_bound hov + rw [hadd] + have hge : r ≤ r + needsUp := Nat.le_add_right r needsUp + have hlt_eq : evmLt (r + needsUp) r = 0 := by + unfold evmLt u256 + simp only [Nat.mod_eq_of_lt hov, Nat.mod_eq_of_lt hr] + exact if_neg (Nat.not_lt.mpr hge) + rw [hlt_eq]; simp + · -- Overflow: r + needsUp ≥ WORD_MOD, so needsUp = 1 and r = WORD_MOD - 1 + have hov' : WORD_MOD ≤ r + needsUp := Nat.not_lt.mp hov + have hn1 : needsUp = 1 := by rcases hn with h | h <;> omega + subst hn1 + have hr_max : r = WORD_MOD - 1 := by omega + subst hr_max + -- evmAdd (WORD_MOD - 1) 1 = 0 (overflow) + have hadd : evmAdd (WORD_MOD - 1) 1 = 0 := by + unfold evmAdd u256 WORD_MOD; simp + rw [hadd] + -- evmLt 0 (WORD_MOD - 1) = 1 (since 0 < WORD_MOD - 1) + have hlt_eq : evmLt 0 (WORD_MOD - 1) = 1 := by + unfold evmLt u256 WORD_MOD; simp + rw [hlt_eq] + unfold WORD_MOD; omega + +-- ============================================================================ +-- Section 5: Main theorem — model_osqrtUp_evm = sqrtUp512 +-- ============================================================================ + +set_option exponentiation.threshold 1024 in +/-- The EVM model of osqrtUp(uint512, uint512) computes sqrtUp512. -/ +theorem model_osqrtUp_evm_correct (x_hi x_lo : Nat) + (hxhi : x_hi < 2 ^ 256) (hxlo : x_lo < 2 ^ 256) : + let (r_hi, r_lo) := model_osqrtUp_evm x_hi x_lo + let x := x_hi * 2 ^ 256 + x_lo + r_hi * 2 ^ 256 + r_lo = sqrtUp512 x := by + simp only + by_cases hxhi0 : x_hi = 0 + · -- x_hi = 0: use 256-bit ceiling sqrt bridge + subst hxhi0 + simp only [Nat.zero_mul, Nat.zero_add] + rw [osqrtUp_zero_fst, osqrtUp_zero_snd] + simp only [Nat.zero_mul, Nat.zero_add] + -- model_sqrt_up_evm x_lo satisfies ceiling sqrt spec + have hspec := SqrtGeneratedModel.model_sqrt_up_evm_ceil_u256 x_lo hxlo + -- sqrtUp512 x_lo also satisfies it (for x_lo < 2^256 < 2^512) + have hx512 : x_lo < 2 ^ 512 := by + calc x_lo < 2 ^ 256 := hxlo + _ ≤ 2 ^ 512 := Nat.pow_le_pow_right (by omega) (by omega) + -- Both satisfy the same uniqueness property + exact sqrtUp512_unique x_lo (SqrtGeneratedModel.model_sqrt_up_evm x_lo) hx512 + hspec.1 hspec.2 + · -- x_hi > 0: floor sqrt + carry + have hxhi_pos : 0 < x_hi := Nat.pos_of_ne_zero hxhi0 + -- Convert 2^256 to WORD_MOD for local use + have hWM : WORD_MOD = 2 ^ 256 := rfl + have hr_wm : x_hi < WORD_MOD := by rwa [hWM] + have hlo_wm : x_lo < WORD_MOD := by rwa [hWM] + -- Unfold model and simplify u256 on valid inputs + unfold model_osqrtUp_evm + have hxhi_u : u256 x_hi = x_hi := u256_id' x_hi hr_wm + have hxlo_u : u256 x_lo = x_lo := u256_id' x_lo hlo_wm + simp only [hxhi_u, hxlo_u] + -- Evaluate evmEq x_hi 0 = 0 (since x_hi > 0) + have hneq : evmEq x_hi 0 = 0 := by + unfold evmEq; simp [u256_id' x_hi hr_wm]; exact Nat.ne_of_gt hxhi_pos + -- Simplify: evmEq x_hi 0 = 0, then (0 ≠ 0) is decidably False, take else branches + have h0eq0 : ¬((0 : Nat) ≠ 0) := by omega + simp only [hneq, if_neg h0eq0] + -- Abbreviate r = model_sqrt512_evm x_hi x_lo + -- First establish r = natSqrt(x) and r < WORD_MOD + have hr_eq : model_sqrt512_evm x_hi x_lo = natSqrt (x_hi * 2 ^ 256 + x_lo) := + model_sqrt512_evm_correct x_hi x_lo hxhi_pos hxhi hxlo + -- natSqrt(x) < 2^256 because x < 2^512 so natSqrt(x) < 2^256 + have hx_lt : x_hi * 2 ^ 256 + x_lo < 2 ^ 512 := by + calc x_hi * 2 ^ 256 + x_lo + < 2 ^ 256 * 2 ^ 256 := by + have := Nat.mul_lt_mul_of_pos_right hxhi (Nat.two_pow_pos 256) + omega + _ = 2 ^ 512 := by rw [← Nat.pow_add] + have hnatSqrt_bound : natSqrt (x_hi * 2 ^ 256 + x_lo) < 2 ^ 256 := by + suffices h : ¬(2 ^ 256 ≤ natSqrt (x_hi * 2 ^ 256 + x_lo)) by omega + intro h + have h2 := Nat.mul_le_mul h h + have h3 := natSqrt_sq_le (x_hi * 2 ^ 256 + x_lo) + have : 2 ^ 256 * 2 ^ 256 = 2 ^ 512 := by rw [← Nat.pow_add] + omega + have hr_wm' : model_sqrt512_evm x_hi x_lo < WORD_MOD := by + rw [hr_eq, hWM]; exact hnatSqrt_bound + -- Generalize model_sqrt512_evm x_hi x_lo = r + generalize hgen : model_sqrt512_evm x_hi x_lo = r at * + -- Rewrite sq_hi and sq_lo using mul512_high_word and mul512_low_word + rw [mul512_high_word r hr_wm', mul512_low_word r hr_wm'] + -- Establish bounds for sq_hi and sq_lo + have hsqhi_bound : r * r / WORD_MOD < WORD_MOD := by + have : r * r < WORD_MOD * WORD_MOD := + Nat.mul_lt_mul_of_le_of_lt (Nat.le_of_lt hr_wm') hr_wm' (by unfold WORD_MOD; omega) + exact Nat.div_lt_of_lt_mul this + have hsqlo_bound : r * r % WORD_MOD < WORD_MOD := Nat.mod_lt _ (by unfold WORD_MOD; omega) + -- Generalize the needsUp expression + generalize hnu_def : evmOr (evmGt x_hi (r * r / WORD_MOD)) + (evmAnd (evmEq x_hi (r * r / WORD_MOD)) (evmGt x_lo (r * r % WORD_MOD))) = needsUp + -- needsUp ∈ {0, 1} + have hnu_01 : needsUp = 0 ∨ needsUp = 1 := by + rw [← hnu_def] + have hgt_01 : ∀ a b : Nat, a < WORD_MOD → b < WORD_MOD → + evmGt a b = 0 ∨ evmGt a b = 1 := by + intro a b ha hb; unfold evmGt + simp only [u256_id' a ha, u256_id' b hb]; by_cases h : a > b <;> simp [h] + have heq_01 : ∀ a b : Nat, a < WORD_MOD → b < WORD_MOD → + evmEq a b = 0 ∨ evmEq a b = 1 := by + intro a b ha hb; unfold evmEq + simp only [u256_id' a ha, u256_id' b hb]; by_cases h : a = b <;> simp [h] + have hand_01 : ∀ a b : Nat, (a = 0 ∨ a = 1) → (b = 0 ∨ b = 1) → + evmAnd a b = 0 ∨ evmAnd a b = 1 := by + intro a b ha hb + rcases ha with rfl | rfl <;> rcases hb with rfl | rfl <;> + (unfold evmAnd u256 WORD_MOD; simp (config := { decide := true })) + have hor_01 : ∀ a b : Nat, (a = 0 ∨ a = 1) → (b = 0 ∨ b = 1) → + evmOr a b = 0 ∨ evmOr a b = 1 := by + intro a b ha hb + rcases ha with rfl | rfl <;> rcases hb with rfl | rfl <;> + (unfold evmOr u256 WORD_MOD; simp (config := { decide := true })) + exact hor_01 _ _ + (hgt_01 x_hi (r * r / WORD_MOD) hr_wm hsqhi_bound) + (hand_01 _ _ + (heq_01 x_hi (r * r / WORD_MOD) hr_wm hsqhi_bound) + (hgt_01 x_lo (r * r % WORD_MOD) hlo_wm hsqlo_bound)) + -- Key semantic fact: needsUp ≠ 0 ↔ x_hi * W + x_lo > r * r + have hnu_iff : (needsUp ≠ 0) ↔ (x_hi * WORD_MOD + x_lo > r * r) := by + rw [← hnu_def] + have h := gt512_correct x_hi x_lo (r * r / WORD_MOD) (r * r % WORD_MOD) + hr_wm hlo_wm hsqhi_bound hsqlo_bound + simp only at h + -- h: (...) ↔ x_hi * WORD_MOD + x_lo > r*r/WORD_MOD * WORD_MOD + r*r % WORD_MOD + -- Nat.div_add_mod gives WORD_MOD * (r*r / WORD_MOD) + ..., need to commute + have hdm : r * r / WORD_MOD * WORD_MOD + r * r % WORD_MOD = r * r := by + rw [Nat.mul_comm]; exact Nat.div_add_mod .. + rw [hdm] at h; exact h + -- Use add_with_carry + have hcarry := add_with_carry r needsUp hr_wm' hnu_01 + simp only at hcarry + -- evmAdd 0 x = x when x ∈ {0, 1} + have hfst_simp : evmAdd 0 (evmLt (evmAdd r needsUp) r) = + evmLt (evmAdd r needsUp) r := by + have hlt_01 : evmLt (evmAdd r needsUp) r = 0 ∨ evmLt (evmAdd r needsUp) r = 1 := by + unfold evmLt; by_cases h : u256 (evmAdd r needsUp) < u256 r <;> simp [h] + rcases hlt_01 with h | h <;> + (rw [h]; unfold evmAdd u256 WORD_MOD; simp (config := { decide := true })) + rw [hfst_simp, ← hWM, hcarry] + -- Goal: r + needsUp = sqrtUp512 (x_hi * WORD_MOD + x_lo) + -- Rewrite hr_eq to use WORD_MOD + have hr_eq_wm : r = natSqrt (x_hi * WORD_MOD + x_lo) := by rw [hr_eq, hWM] + have hx_lt_wm : x_hi * WORD_MOD + x_lo < 2 ^ 512 := by rw [hWM]; exact hx_lt + have hsqrt512_eq : sqrt512 (x_hi * WORD_MOD + x_lo) = + natSqrt (x_hi * WORD_MOD + x_lo) := + sqrt512_correct (x_hi * WORD_MOD + x_lo) hx_lt_wm + unfold sqrtUp512 + simp only + rw [hsqrt512_eq, ← hr_eq_wm] + -- Goal: r + needsUp = if r * r < x_hi * WORD_MOD + x_lo then r + 1 else r + by_cases hlt : r * r < x_hi * WORD_MOD + x_lo + · -- r*r < x: needsUp = 1 + simp only [hlt, ite_true] + have hnu_nz : needsUp ≠ 0 := hnu_iff.mpr hlt + rcases hnu_01 with h | h + · exact absurd h hnu_nz + · rw [h] + · -- r*r ≥ x: needsUp = 0 + simp only [hlt, ite_false] + have hnu_z : needsUp = 0 := by + rcases hnu_01 with h | h + · exact h + · exfalso; have := hnu_iff.mp (by rw [h]; omega); omega + rw [hnu_z]; omega + +end Sqrt512Spec diff --git a/formal/sqrt/Sqrt512Proof/Sqrt512Proof/Sqrt512Correct.lean b/formal/sqrt/Sqrt512Proof/Sqrt512Proof/Sqrt512Correct.lean new file mode 100644 index 000000000..a77925a1e --- /dev/null +++ b/formal/sqrt/Sqrt512Proof/Sqrt512Proof/Sqrt512Correct.lean @@ -0,0 +1,193 @@ +/- + End-to-end correctness of 512-bit square root. + + Composes normalization, Karatsuba step, correction, and un-normalization: + + sqrt512(x) = natSqrt(x) for x < 2^512 +-/ +import SqrtProof.SqrtCorrect +import Sqrt512Proof.Normalization +import Sqrt512Proof.KaratsubaStep +import Sqrt512Proof.Correction + +-- ============================================================================ +-- Part 1: Karatsuba uncorrected and corrected +-- ============================================================================ + +/-- Uncorrected Karatsuba result (before correction step). -/ +noncomputable def karatsubaR (x_hi x_lo : Nat) : Nat := + let H := 2 ^ 128 + let r_hi := natSqrt x_hi + let res := x_hi - r_hi * r_hi + let x_lo_hi := x_lo / H + let n := res * H + x_lo_hi + let d := 2 * r_hi + let q := n / d + r_hi * H + q + +/-- The full Karatsuba floor sqrt with correction. -/ +noncomputable def karatsubaFloor (x_hi x_lo : Nat) : Nat := + let H := 2 ^ 128 + let x_lo_hi := x_lo / H + let x_lo_lo := x_lo % H + let x := x_hi * (H * H) + x_lo_hi * H + x_lo_lo + let r := karatsubaR x_hi x_lo + if x < r * r then r - 1 else r + +/-- karatsubaR satisfies the Karatsuba bracket for normalized inputs. -/ +theorem karatsubaR_bracket (x_hi x_lo : Nat) + (hxhi_lo : 2 ^ 254 ≤ x_hi) + (hxlo : x_lo < 2 ^ 256) : + let H := 2 ^ 128 + let x_lo_hi := x_lo / H + let x_lo_lo := x_lo % H + let x := x_hi * (H * H) + x_lo_hi * H + x_lo_lo + let r := karatsubaR x_hi x_lo + natSqrt x ≤ r ∧ r ≤ natSqrt x + 1 := by + simp only + have h128sq : (2 : Nat) ^ 128 * 2 ^ 128 = 2 ^ 256 := by rw [← Nat.pow_add] + exact karatsuba_bracket_512 x_hi (x_lo / 2 ^ 128) (x_lo % 2 ^ 128) + hxhi_lo + (Nat.div_lt_of_lt_mul (by rwa [h128sq])) + (Nat.mod_lt x_lo (Nat.two_pow_pos 128)) + +/-- karatsubaFloor = natSqrt for normalized inputs. -/ +theorem karatsubaFloor_eq_natSqrt (x_hi x_lo : Nat) + (hxhi_lo : 2 ^ 254 ≤ x_hi) + (hxlo : x_lo < 2 ^ 256) : + karatsubaFloor x_hi x_lo = natSqrt (x_hi * 2 ^ 256 + x_lo) := by + have hHsq : (2 : Nat) ^ 128 * ((2 : Nat) ^ 128) = (2 : Nat) ^ 256 := by rw [← Nat.pow_add] + have hxlo_decomp : x_lo = x_lo / (2 : Nat) ^ 128 * (2 : Nat) ^ 128 + x_lo % (2 : Nat) ^ 128 := by + have := (Nat.div_add_mod x_lo ((2 : Nat) ^ 128)).symm + rw [Nat.mul_comm] at this; exact this + + have hx_eq : x_hi * 2 ^ 256 + x_lo = + x_hi * ((2 : Nat) ^ 128 * (2 : Nat) ^ 128) + + x_lo / (2 : Nat) ^ 128 * (2 : Nat) ^ 128 + x_lo % (2 : Nat) ^ 128 := by + rw [← hHsq, hxlo_decomp]; omega + + have hbracket := karatsubaR_bracket x_hi x_lo hxhi_lo hxlo + + rw [hx_eq] + unfold karatsubaFloor + simp only + exact correction_correct + (x_hi * ((2 : Nat) ^ 128 * (2 : Nat) ^ 128) + + x_lo / (2 : Nat) ^ 128 * (2 : Nat) ^ 128 + x_lo % (2 : Nat) ^ 128) + (karatsubaR x_hi x_lo) hbracket.1 hbracket.2 + +-- ============================================================================ +-- Part 2: Normalization bounds +-- ============================================================================ + +private theorem four_pow_eq_two_pow' (shift : Nat) : 4 ^ shift = 2 ^ (2 * shift) := by + have : (4 : Nat) = 2 ^ 2 := by decide + rw [this, ← Nat.pow_mul] + +/-- x * 4^shift < 2^512 when x and shift are properly constrained. -/ +private theorem normalized_lt_512 (x x_hi : Nat) + (hxhi_pos : 0 < x_hi) (hxhi_lt : x_hi < 2 ^ 256) + (hx_lt : x < (x_hi + 1) * 2 ^ 256) : + let shift := (255 - Nat.log2 x_hi) / 2 + x * 4 ^ shift < 2 ^ 512 := by + intro shift + have hne : x_hi ≠ 0 := Nat.ne_of_gt hxhi_pos + have hlog := (Nat.log2_eq_iff hne).1 rfl + have hL : Nat.log2 x_hi ≤ 255 := by have := (Nat.log2_lt hne).2 hxhi_lt; omega + have h2shift : 2 * shift ≤ 255 - Nat.log2 x_hi := Nat.mul_div_le (255 - Nat.log2 x_hi) 2 + rw [four_pow_eq_two_pow'] + calc x * 2 ^ (2 * shift) + < (x_hi + 1) * 2 ^ 256 * 2 ^ (2 * shift) := + Nat.mul_lt_mul_of_pos_right hx_lt (Nat.two_pow_pos _) + _ ≤ 2 ^ (Nat.log2 x_hi + 1) * 2 ^ 256 * 2 ^ (2 * shift) := + Nat.mul_le_mul_right _ (Nat.mul_le_mul_right _ hlog.2) + _ = 2 ^ (Nat.log2 x_hi + 1 + 256 + 2 * shift) := by + rw [← Nat.pow_add, ← Nat.pow_add] + _ ≤ 2 ^ 512 := Nat.pow_le_pow_right (by omega) (by omega) + +/-- The normalized top word is >= 2^254. -/ +private theorem normalized_hi_lower (x x_hi : Nat) + (hxhi_pos : 0 < x_hi) (hxhi_lt : x_hi < 2 ^ 256) + (hx_ge : x_hi * 2 ^ 256 ≤ x) : + let shift := (255 - Nat.log2 x_hi) / 2 + 2 ^ 254 ≤ x * 4 ^ shift / 2 ^ 256 := by + intro shift + have hsr := shift_range x_hi hxhi_pos hxhi_lt + have h1 : x_hi * 2 ^ 256 * 4 ^ shift ≤ x * 4 ^ shift := + Nat.mul_le_mul_right _ hx_ge + have h2 : x_hi * 4 ^ shift ≤ x * 4 ^ shift / 2 ^ 256 := by + rw [Nat.le_div_iff_mul_le (Nat.two_pow_pos 256)] + calc x_hi * 4 ^ shift * 2 ^ 256 + = x_hi * 2 ^ 256 * 4 ^ shift := by + rw [Nat.mul_assoc, Nat.mul_comm (4 ^ shift) (2 ^ 256), ← Nat.mul_assoc] + _ ≤ x * 4 ^ shift := h1 + exact Nat.le_trans hsr.1 h2 + +-- ============================================================================ +-- Part 3: The full 512-bit sqrt +-- ============================================================================ + +/-- 512-bit floor square root (Nat model). -/ +noncomputable def sqrt512 (x : Nat) : Nat := + if x < 2 ^ 256 then + natSqrt x + else + let x_hi := x / 2 ^ 256 + let _x_lo := x % 2 ^ 256 + let shift := (255 - Nat.log2 x_hi) / 2 + let x' := x * 4 ^ shift + karatsubaFloor (x' / 2 ^ 256) (x' % 2 ^ 256) / 2 ^ shift + +set_option exponentiation.threshold 1024 in +/-- sqrt512 is correct for x < 2^512. -/ +theorem sqrt512_correct (x : Nat) (hx : x < 2 ^ 512) : + sqrt512 x = natSqrt x := by + unfold sqrt512 + by_cases hlt : x < 2 ^ 256 + · simp [hlt] + · simp [hlt] + have hge : 2 ^ 256 ≤ x := by omega + have hxhi_pos : 0 < x / 2 ^ 256 := + Nat.div_pos hge (Nat.two_pow_pos 256) + have hxhi_lt : x / 2 ^ 256 < 2 ^ 256 := by + rw [Nat.div_lt_iff_lt_mul (Nat.two_pow_pos 256)] + calc x < 2 ^ 512 := hx + _ = 2 ^ 256 * 2 ^ 256 := by rw [← Nat.pow_add] + have hxlo_bound : x % 2 ^ 256 < 2 ^ 256 := Nat.mod_lt x (Nat.two_pow_pos 256) + have hx_decomp : x = x / 2 ^ 256 * 2 ^ 256 + x % 2 ^ 256 := by + have := (Nat.div_add_mod x (2 ^ 256)).symm + rw [Nat.mul_comm] at this; exact this + + have hx_lt : x < (x / 2 ^ 256 + 1) * 2 ^ 256 := by omega + have hx'_lt := normalized_lt_512 x (x / 2 ^ 256) hxhi_pos hxhi_lt hx_lt + have hx_ge : x / 2 ^ 256 * 2 ^ 256 ≤ x := by omega + have hxhi'_lo := normalized_hi_lower x (x / 2 ^ 256) hxhi_pos hxhi_lt hx_ge + have h256sq : (2 : Nat) ^ 256 * 2 ^ 256 = 2 ^ 512 := by rw [← Nat.pow_add] + have hxhi'_lt : x * 4 ^ ((255 - Nat.log2 (x / 2 ^ 256)) / 2) / 2 ^ 256 < 2 ^ 256 := + Nat.div_lt_of_lt_mul (by rwa [h256sq]) + have hxlo'_bound : x * 4 ^ ((255 - Nat.log2 (x / 2 ^ 256)) / 2) % 2 ^ 256 < 2 ^ 256 := + Nat.mod_lt _ (Nat.two_pow_pos 256) + + have hkf := karatsubaFloor_eq_natSqrt + (x * 4 ^ ((255 - Nat.log2 (x / 2 ^ 256)) / 2) / 2 ^ 256) + (x * 4 ^ ((255 - Nat.log2 (x / 2 ^ 256)) / 2) % 2 ^ 256) + hxhi'_lo hxlo'_bound + -- hkf : karatsubaFloor (x'/2^256) (x'%2^256) = natSqrt (x'/2^256 * 2^256 + x'%2^256) + -- We need: karatsubaFloor (x'/2^256) (x'%2^256) / 2^shift = natSqrt x + -- Since x'/2^256 * 2^256 + x'%2^256 = x' (Euclidean decomposition) + -- hkf gives karatsubaFloor ... = natSqrt x' + -- Then natSqrt x' / 2^shift = natSqrt x (by natSqrt_shift_div) + have hx'_eq : x * 4 ^ ((255 - Nat.log2 (x / 2 ^ 256)) / 2) / 2 ^ 256 * 2 ^ 256 + + x * 4 ^ ((255 - Nat.log2 (x / 2 ^ 256)) / 2) % 2 ^ 256 = + x * 4 ^ ((255 - Nat.log2 (x / 2 ^ 256)) / 2) := by + have := Nat.div_add_mod (x * 4 ^ ((255 - Nat.log2 (x / 2 ^ 256)) / 2)) (2 ^ 256) + rw [Nat.mul_comm] at this; omega + rw [hkf, hx'_eq] + exact natSqrt_shift_div x ((255 - Nat.log2 (x / 2 ^ 256)) / 2) + +/-- sqrt512 satisfies the integer square root spec. -/ +theorem sqrt512_spec (x : Nat) (hx : x < 2 ^ 512) : + let r := sqrt512 x + r * r ≤ x ∧ x < (r + 1) * (r + 1) := by + simp only; rw [sqrt512_correct x hx] + exact ⟨natSqrt_sq_le x, natSqrt_lt_succ_sq x⟩ diff --git a/formal/sqrt/Sqrt512Proof/Sqrt512Proof/SqrtUpCorrect.lean b/formal/sqrt/Sqrt512Proof/Sqrt512Proof/SqrtUpCorrect.lean new file mode 100644 index 000000000..67121001a --- /dev/null +++ b/formal/sqrt/Sqrt512Proof/Sqrt512Proof/SqrtUpCorrect.lean @@ -0,0 +1,69 @@ +/- + Ceiling square root for 512-bit values. + Models 512Math.osqrtUp (lines 1798-1811). +-/ +import Sqrt512Proof.Sqrt512Correct + +private theorem sq_expand_aux (m : Nat) : + (m + 1) * (m + 1) = m * m + 2 * m + 1 := by + rw [Nat.add_mul, Nat.mul_add, Nat.mul_add, Nat.mul_one, Nat.one_mul] + -- Goal: m * m + m + (m + 1 * 1) = m * m + 2 * m + 1 + rw [Nat.one_mul]; omega + +/-- 512-bit ceiling square root. -/ +noncomputable def sqrtUp512 (x : Nat) : Nat := + let r := sqrt512 x + if r * r < x then r + 1 else r + +/-- sqrtUp512 is the ceiling sqrt: x <= r^2 and r is minimal. -/ +theorem sqrtUp512_correct (x : Nat) (hx : x < 2 ^ 512) : + let r := sqrtUp512 x + x ≤ r * r ∧ ∀ y, x ≤ y * y → r ≤ y := by + simp only + have hsqrt := sqrt512_correct x hx + have hs_lo : sqrt512 x * sqrt512 x ≤ x := by rw [hsqrt]; exact natSqrt_sq_le x + have hs_hi : x < (sqrt512 x + 1) * (sqrt512 x + 1) := by rw [hsqrt]; exact natSqrt_lt_succ_sq x + unfold sqrtUp512 + simp only + by_cases hlt : sqrt512 x * sqrt512 x < x + · -- s^2 < x: ceiling is s + 1 + simp [hlt] + exact ⟨by omega, fun y hy => by + suffices h : ¬(y < sqrt512 x + 1) by omega + intro hc + have hc' : y ≤ sqrt512 x := by omega + have := Nat.mul_le_mul hc' hc'; omega⟩ + · -- s^2 = x: ceiling is s + simp [hlt] + have hseq : sqrt512 x * sqrt512 x = x := by omega + exact ⟨by omega, fun y hy => by + suffices h : ¬(y < sqrt512 x) by omega + intro hc + have hc' : y ≤ sqrt512 x - 1 := by omega + have h1 := Nat.mul_le_mul hc' hc' + -- y*y ≤ (s-1)*(s-1) and (s-1)*(s-1) < s*s = x + -- (s-1)*(s-1) < s*s since (s-1+1)*(s-1+1) = (s-1)*(s-1) + 2*(s-1) + 1 + have h2 : 0 < sqrt512 x := by omega + have h3 := sq_expand_aux (sqrt512 x - 1) + -- h3 : (s-1+1)*(s-1+1) = (s-1)*(s-1) + 2*(s-1) + 1 + have h4 : (sqrt512 x - 1) + 1 = sqrt512 x := by omega + rw [h4] at h3 + -- h3 : s*s = (s-1)*(s-1) + 2*(s-1) + 1 + -- so (s-1)*(s-1) = s*s - 2*(s-1) - 1 < s*s = x + -- and y*y ≤ (s-1)*(s-1) < x, contradicting x ≤ y*y + omega⟩ + +/-- sqrtUp512 satisfies the ceiling sqrt spec. -/ +theorem sqrtUp512_spec (x : Nat) (hx : x < 2 ^ 512) : + let r := sqrtUp512 x + x ≤ r * r ∧ (r = 0 ∨ (r - 1) * (r - 1) < x) := by + have ⟨h1, h2⟩ := sqrtUp512_correct x hx + simp only at h1 h2 ⊢ + refine ⟨h1, ?_⟩ + by_cases hr0 : sqrtUp512 x = 0 + · left; exact hr0 + · right + suffices h : ¬((sqrtUp512 x - 1) * (sqrtUp512 x - 1) ≥ x) by omega + intro hc + have := h2 (sqrtUp512 x - 1) hc + omega diff --git a/formal/sqrt/Sqrt512Proof/Sqrt512Proof/SqrtWrapperSpec.lean b/formal/sqrt/Sqrt512Proof/Sqrt512Proof/SqrtWrapperSpec.lean new file mode 100644 index 000000000..2ee56555d --- /dev/null +++ b/formal/sqrt/Sqrt512Proof/Sqrt512Proof/SqrtWrapperSpec.lean @@ -0,0 +1,193 @@ +/- + Bridge proof: model_sqrt512_wrapper_evm computes natSqrt. + + The auto-generated model_sqrt512_wrapper_evm dispatches: + x_hi = 0 ⟹ inlined 256-bit floor sqrt (= model_sqrt_floor_evm from SqrtProof) + x_hi > 0 ⟹ model_sqrt512_evm (already proved correct) +-/ +import Sqrt512Proof.GeneratedSqrt512Model +import Sqrt512Proof.GeneratedSqrt512Spec +import SqrtProof.GeneratedSqrtModel +import SqrtProof.GeneratedSqrtSpec +import SqrtProof.SqrtCorrect + +namespace Sqrt512Spec + +open Sqrt512GeneratedModel + +-- ============================================================================ +-- Section 1: Namespace compatibility +-- Both SqrtGeneratedModel and Sqrt512GeneratedModel define identical opcodes. +-- We prove extensional equality so we can rewrite the wrapper's x_hi=0 branch +-- from Sqrt512GeneratedModel ops to SqrtGeneratedModel ops. +-- ============================================================================ + +section NamespaceCompat + +theorem WORD_MOD_compat : + @Sqrt512GeneratedModel.WORD_MOD = @SqrtGeneratedModel.WORD_MOD := rfl + +theorem u256_compat (x : Nat) : + Sqrt512GeneratedModel.u256 x = SqrtGeneratedModel.u256 x := by + unfold Sqrt512GeneratedModel.u256 SqrtGeneratedModel.u256 + rw [WORD_MOD_compat] + +theorem evmAdd_compat (a b : Nat) : + Sqrt512GeneratedModel.evmAdd a b = SqrtGeneratedModel.evmAdd a b := by + unfold Sqrt512GeneratedModel.evmAdd SqrtGeneratedModel.evmAdd + simp [u256_compat] + +theorem evmSub_compat (a b : Nat) : + Sqrt512GeneratedModel.evmSub a b = SqrtGeneratedModel.evmSub a b := by + unfold Sqrt512GeneratedModel.evmSub SqrtGeneratedModel.evmSub + simp [u256_compat, WORD_MOD_compat] + +theorem evmMul_compat (a b : Nat) : + Sqrt512GeneratedModel.evmMul a b = SqrtGeneratedModel.evmMul a b := by + unfold Sqrt512GeneratedModel.evmMul SqrtGeneratedModel.evmMul + simp [u256_compat] + +theorem evmDiv_compat (a b : Nat) : + Sqrt512GeneratedModel.evmDiv a b = SqrtGeneratedModel.evmDiv a b := by + unfold Sqrt512GeneratedModel.evmDiv SqrtGeneratedModel.evmDiv + simp [u256_compat] + +theorem evmShl_compat (s v : Nat) : + Sqrt512GeneratedModel.evmShl s v = SqrtGeneratedModel.evmShl s v := by + unfold Sqrt512GeneratedModel.evmShl SqrtGeneratedModel.evmShl + simp [u256_compat] + +theorem evmShr_compat (s v : Nat) : + Sqrt512GeneratedModel.evmShr s v = SqrtGeneratedModel.evmShr s v := by + unfold Sqrt512GeneratedModel.evmShr SqrtGeneratedModel.evmShr + simp [u256_compat] + +theorem evmClz_compat (v : Nat) : + Sqrt512GeneratedModel.evmClz v = SqrtGeneratedModel.evmClz v := by + unfold Sqrt512GeneratedModel.evmClz SqrtGeneratedModel.evmClz + simp [u256_compat] + +theorem evmLt_compat (a b : Nat) : + Sqrt512GeneratedModel.evmLt a b = SqrtGeneratedModel.evmLt a b := by + unfold Sqrt512GeneratedModel.evmLt SqrtGeneratedModel.evmLt + simp [u256_compat] + +theorem evmEq_compat (a b : Nat) : + Sqrt512GeneratedModel.evmEq a b = SqrtGeneratedModel.evmEq a b := by + unfold Sqrt512GeneratedModel.evmEq SqrtGeneratedModel.evmEq + simp [u256_compat] + +theorem evmGt_compat (a b : Nat) : + Sqrt512GeneratedModel.evmGt a b = SqrtGeneratedModel.evmGt a b := by + unfold Sqrt512GeneratedModel.evmGt SqrtGeneratedModel.evmGt + simp [u256_compat] + +theorem evmNot_compat (a : Nat) : + Sqrt512GeneratedModel.evmNot a = SqrtGeneratedModel.evmNot a := by + unfold Sqrt512GeneratedModel.evmNot SqrtGeneratedModel.evmNot + simp [u256_compat, WORD_MOD_compat] + +theorem evmMulmod_compat (a b n : Nat) : + Sqrt512GeneratedModel.evmMulmod a b n = SqrtGeneratedModel.evmMulmod a b n := by + unfold Sqrt512GeneratedModel.evmMulmod SqrtGeneratedModel.evmMulmod + simp [u256_compat] + +end NamespaceCompat + +-- ============================================================================ +-- Section 2: The wrapper's x_hi=0 branch equals model_sqrt_floor_evm +-- ============================================================================ + +/-- u256 is idempotent: u256(u256(x)) = u256(x). -/ +private theorem u256_idem (x : Nat) : + Sqrt512GeneratedModel.u256 (Sqrt512GeneratedModel.u256 x) = Sqrt512GeneratedModel.u256 x := by + unfold Sqrt512GeneratedModel.u256 Sqrt512GeneratedModel.WORD_MOD + exact Nat.mod_eq_of_lt (Nat.mod_lt x (Nat.two_pow_pos 256)) + +theorem su256_idem (x : Nat) : + SqrtGeneratedModel.u256 (SqrtGeneratedModel.u256 x) = SqrtGeneratedModel.u256 x := by + unfold SqrtGeneratedModel.u256 SqrtGeneratedModel.WORD_MOD + exact Nat.mod_eq_of_lt (Nat.mod_lt x (Nat.two_pow_pos 256)) + +theorem su256_zero : SqrtGeneratedModel.u256 0 = 0 := by + unfold SqrtGeneratedModel.u256 SqrtGeneratedModel.WORD_MOD; simp + +/-- When x_hi = 0, model_sqrt512_wrapper_evm calls model_sqrt256_floor_evm, + which is identical (modulo namespace) to model_sqrt_floor_evm from SqrtProof. -/ +theorem wrapper_zero_eq_sqrt_floor_evm (x_lo : Nat) : + model_sqrt512_wrapper_evm 0 x_lo = SqrtGeneratedModel.model_sqrt_floor_evm x_lo := by + -- Unfold all model definitions to expose the full EVM expression + simp only [model_sqrt512_wrapper_evm, model_sqrt256_floor_evm, + SqrtGeneratedModel.model_sqrt_floor_evm, SqrtGeneratedModel.model_sqrt_evm] + -- Convert Sqrt512 namespace ops to SqrtGeneratedModel ops + simp only [evmEq_compat, evmShr_compat, evmAdd_compat, evmDiv_compat, + evmSub_compat, evmClz_compat, evmShl_compat, evmLt_compat, u256_compat] + -- Simplify: u256(u256(x)) = u256(x) and u256(0) = 0 + simp only [su256_zero, su256_idem] + -- Simplify the conditional: if True then 1 else 0 = 1, 1 ≠ 0 = True, take then-branch + simp (config := { decide := true }) + +-- ============================================================================ +-- Section 3: natSqrt uniqueness bridge +-- ============================================================================ + +/-- The integer square root is unique: if r² ≤ n < (r+1)² then r = natSqrt n. -/ +theorem natSqrt_unique (n r : Nat) (hlo : r * r ≤ n) (hhi : n < (r + 1) * (r + 1)) : + r = natSqrt n := by + have hs := natSqrt_spec n + -- natSqrt n * natSqrt n ≤ n ∧ n < (natSqrt n + 1) * (natSqrt n + 1) + suffices h : ¬(r < natSqrt n) ∧ ¬(natSqrt n < r) by omega + constructor + · intro hlt + have hle : r + 1 ≤ natSqrt n := by omega + have := Nat.mul_le_mul hle hle + omega + · intro hlt + have hle : natSqrt n + 1 ≤ r := by omega + have := Nat.mul_le_mul hle hle + omega + +/-- floorSqrt = natSqrt for uint256 inputs. -/ +theorem floorSqrt_eq_natSqrt (x : Nat) (hx : x < 2 ^ 256) : + floorSqrt x = natSqrt x := by + have ⟨hlo, hhi⟩ := floorSqrt_correct_u256 x hx + exact natSqrt_unique x (floorSqrt x) hlo hhi + +-- ============================================================================ +-- Section 4: Main theorem — model_sqrt512_wrapper_evm = natSqrt +-- ============================================================================ + +set_option exponentiation.threshold 512 in +/-- The EVM model of the sqrt(uint512) wrapper computes natSqrt. -/ +theorem model_sqrt512_wrapper_evm_correct (x_hi x_lo : Nat) + (hxhi : x_hi < 2 ^ 256) (hxlo : x_lo < 2 ^ 256) : + model_sqrt512_wrapper_evm x_hi x_lo = natSqrt (x_hi * 2 ^ 256 + x_lo) := by + by_cases hxhi0 : x_hi = 0 + · -- x_hi = 0: the wrapper uses the inlined 256-bit floor sqrt + subst hxhi0 + simp only [Nat.zero_mul, Nat.zero_add] + -- Step 1: wrapper's x_hi=0 branch = model_sqrt_floor_evm x_lo + rw [wrapper_zero_eq_sqrt_floor_evm x_lo] + -- Step 2: model_sqrt_floor_evm = floorSqrt + rw [SqrtGeneratedModel.model_sqrt_floor_evm_eq_floorSqrt x_lo hxlo] + -- Step 3: floorSqrt = natSqrt + exact floorSqrt_eq_natSqrt x_lo hxlo + · -- x_hi > 0: use the existing model_sqrt512_evm_correct + have hxhi_pos : 0 < x_hi := Nat.pos_of_ne_zero hxhi0 + -- The wrapper's else-branch calls model_sqrt512_evm directly. + -- After unfolding the wrapper, the else-branch is model_sqrt512_evm (u256 x_hi) (u256 x_lo). + -- Since x_hi, x_lo < 2^256, u256 is identity. + unfold model_sqrt512_wrapper_evm + have hxhi_u : u256 x_hi = x_hi := u256_id' x_hi (by rwa [WORD_MOD]) + have hxlo_u : u256 x_lo = x_lo := u256_id' x_lo (by rwa [WORD_MOD]) + simp only [hxhi_u, hxlo_u] + -- evmEq x_hi 0 = 0 when x_hi > 0 + have hneq : evmEq x_hi 0 = 0 := by + unfold evmEq + simp [u256_id' x_hi (by rwa [WORD_MOD] : x_hi < WORD_MOD)] + exact Nat.ne_of_gt hxhi_pos + simp only [hneq] + -- Now goal is: model_sqrt512_evm x_hi x_lo = natSqrt (x_hi * 2^256 + x_lo) + exact model_sqrt512_evm_correct x_hi x_lo hxhi_pos hxhi hxlo + +end Sqrt512Spec diff --git a/formal/sqrt/Sqrt512Proof/lakefile.toml b/formal/sqrt/Sqrt512Proof/lakefile.toml new file mode 100644 index 000000000..3daa9e7df --- /dev/null +++ b/formal/sqrt/Sqrt512Proof/lakefile.toml @@ -0,0 +1,14 @@ +name = "Sqrt512Proof" +version = "0.1.0" +defaultTargets = ["Sqrt512Proof"] + +[[lean_lib]] +name = "Sqrt512Proof" + +[[lean_exe]] +name = "sqrt512-model" +root = "Main" + +[[require]] +name = "SqrtProof" +path = "../SqrtProof" diff --git a/formal/sqrt/Sqrt512Proof/lean-toolchain b/formal/sqrt/Sqrt512Proof/lean-toolchain new file mode 100644 index 000000000..4c685fa08 --- /dev/null +++ b/formal/sqrt/Sqrt512Proof/lean-toolchain @@ -0,0 +1 @@ +leanprover/lean4:v4.28.0 diff --git a/formal/sqrt/SqrtProof/.gitignore b/formal/sqrt/SqrtProof/.gitignore new file mode 100644 index 000000000..78233b5a4 --- /dev/null +++ b/formal/sqrt/SqrtProof/.gitignore @@ -0,0 +1,7 @@ +/.lake + +# Auto-generated from `formal/sqrt/generate_sqrt_cert.py` +/SqrtProof/FiniteCert.lean + +# Auto-generated from `formal/sqrt/generate_sqrt_model.py` +/SqrtProof/GeneratedSqrtModel.lean diff --git a/formal/sqrt/SqrtProof/Main.lean b/formal/sqrt/SqrtProof/Main.lean new file mode 100644 index 000000000..85210ff03 --- /dev/null +++ b/formal/sqrt/SqrtProof/Main.lean @@ -0,0 +1,52 @@ +import SqrtProof.GeneratedSqrtModel + +/-! +# Sqrt model evaluator + +Compiled executable for evaluating the generated EVM-faithful Sqrt model +on concrete inputs. Intended for fuzz testing via Foundry's `vm.ffi`. + +Usage: + sqrt-model + +Functions: sqrt, sqrt_floor, sqrt_up + +Output: 0x-prefixed hex uint256 on stdout. +-/ + +open SqrtGeneratedModel in +def evalFunction (name : String) (x : Nat) : Option Nat := + match name with + | "sqrt" => some (model_sqrt_evm x) + | "sqrt_floor" => some (model_sqrt_floor_evm x) + | "sqrt_up" => some (model_sqrt_up_evm x) + | _ => none + +def natToHex64 (n : Nat) : String := + let hex := String.ofList (Nat.toDigits 16 n) + "0x" ++ String.ofList (List.replicate (64 - hex.length) '0') ++ hex + +def parseHex (s : String) : Option Nat := + let s := if s.startsWith "0x" || s.startsWith "0X" then s.drop 2 else s + s.foldl (fun acc c => + acc.bind fun n => + if '0' ≤ c && c ≤ '9' then some (n * 16 + (c.toNat - '0'.toNat)) + else if 'a' ≤ c && c ≤ 'f' then some (n * 16 + (c.toNat - 'a'.toNat + 10)) + else if 'A' ≤ c && c ≤ 'F' then some (n * 16 + (c.toNat - 'A'.toNat + 10)) + else none + ) (some 0) + +def main (args : List String) : IO UInt32 := do + match args with + | [fnName, hexX] => + match parseHex hexX with + | none => IO.eprintln s!"Invalid hex input: {hexX}"; return 1 + | some x => + match evalFunction fnName x with + | none => IO.eprintln s!"Unknown function: {fnName}"; return 1 + | some result => + IO.println (natToHex64 result) + return 0 + | _ => + IO.eprintln "Usage: sqrt-model " + return 1 diff --git a/formal/sqrt/SqrtProof/SqrtProof.lean b/formal/sqrt/SqrtProof/SqrtProof.lean new file mode 100644 index 000000000..5fc3301c9 --- /dev/null +++ b/formal/sqrt/SqrtProof/SqrtProof.lean @@ -0,0 +1,10 @@ +-- This module serves as the root of the `SqrtProof` library. +-- Import modules here that should be built as part of the library. +import SqrtProof.FloorBound +import SqrtProof.StepMono +import SqrtProof.BridgeLemmas +import SqrtProof.FiniteCert +import SqrtProof.CertifiedChain +import SqrtProof.SqrtCorrect +import SqrtProof.GeneratedSqrtModel +import SqrtProof.GeneratedSqrtSpec diff --git a/formal/sqrt/SqrtProof/SqrtProof/BridgeLemmas.lean b/formal/sqrt/SqrtProof/SqrtProof/BridgeLemmas.lean new file mode 100644 index 000000000..10f83002a --- /dev/null +++ b/formal/sqrt/SqrtProof/SqrtProof/BridgeLemmas.lean @@ -0,0 +1,175 @@ +import Init +import SqrtProof.FloorBound + +namespace SqrtBridge + +private theorem hmul2 (a b : Nat) : a * (2 * b) = 2 * (a * b) := by + calc + a * (2 * b) = (a * 2) * b := by rw [Nat.mul_assoc] + _ = (2 * a) * b := by rw [Nat.mul_comm a 2] + _ = 2 * (a * b) := by rw [Nat.mul_assoc] + +private theorem div_split (m d : Nat) (hmd : d ≤ m) : + (m * m + 2 * m) / (m + d) = (m - d) + (d * d + 2 * m) / (m + d) := by + by_cases hzero : m + d = 0 + · have hm0 : m = 0 := by omega + have hd0 : d = 0 := by omega + subst hm0; subst hd0 + simp + · have hsq0 := sq_identity_ge (m + d) m (by omega) (by omega) + have hsq : (m + d) * (m - d) + d * d = m * m := by + have hsub : 2 * m - (m + d) = m - d := by omega + have hdm : (m + d) - m = d := by rw [Nat.add_sub_cancel_left] + rw [hsub, hdm] at hsq0 + exact hsq0 + have hmul : (m + d) * (m - d) + (d * d + 2 * m) = m * m + 2 * m := by + rw [← hsq] + omega + rw [← hmul] + have hpos : 0 < m + d := by omega + rw [Nat.mul_add_div hpos] + +private theorem rhs_eq (s m : Nat) (hs : s ≤ m) : + s * s + m * m + 2 * m - 2 * (s * m) = (m - s) * (m - s) + 2 * m := by + have hsq := sq_identity_le s m hs + rw [← hsq] + rw [Nat.mul_sub, hmul2] + let A := (m - s) * (m - s) + change s * s + (2 * (s * m) - s * s + A) + 2 * m - 2 * (s * m) = A + 2 * m + have hs2 : s * s ≤ 2 * (s * m) := by + have hsm : s * s ≤ s * m := Nat.mul_le_mul_left s hs + omega + have hpre : s * s + (2 * (s * m) - s * s + A) + 2 * m + = (2 * (s * m)) + A + 2 * m := by + omega + rw [hpre] + omega + +private theorem rhs_eq_rev (s m : Nat) (hs : m ≤ s) : + s * s + m * m + 2 * m - 2 * (s * m) = (s - m) * (s - m) + 2 * m := by + have hsq := sq_identity_le m s hs + rw [← hsq] + rw [Nat.mul_sub, hmul2] + let A := (s - m) * (s - m) + have hcomm : 2 * (s * m) = 2 * (m * s) := by rw [Nat.mul_comm s m] + rw [hcomm] + have hm2 : m * m ≤ 2 * (m * s) := by + have hms : m * m ≤ m * s := Nat.mul_le_mul_left m hs + omega + have hpre : 2 * (m * s) - m * m + A + m * m + 2 * m + = 2 * (m * s) + A + 2 * m := by + have hsubadd : 2 * (m * s) - m * m + m * m = 2 * (m * s) := Nat.sub_add_cancel hm2 + omega + rw [hpre] + omega + +/-- One-step error contraction for `z = m + d` with `d ≤ m`. + This is the recurrence used by the finite-certificate bridge. -/ +theorem step_error_bound + (m d x : Nat) + (hm : 0 < m) + (hmd : d ≤ m) + (hxhi : x < (m + 1) * (m + 1)) : + bstep x (m + d) - m ≤ d * d / (2 * m) + 1 := by + unfold bstep + have hxhi' : x < m * m + (m + m) + 1 := by + simpa [Nat.add_mul, Nat.mul_add, Nat.mul_one, Nat.one_mul, + Nat.add_assoc, Nat.add_left_comm, Nat.add_comm] using hxhi + have hx_le : x ≤ m * m + 2 * m := by omega + have hdiv : x / (m + d) ≤ (m * m + 2 * m) / (m + d) := Nat.div_le_div_right hx_le + rw [div_split m d hmd] at hdiv + have hsum : m + d + x / (m + d) ≤ 2 * m + (d * d + 2 * m) / (m + d) := by + omega + have hhalf : (m + d + x / (m + d)) / 2 ≤ (2 * m + (d * d + 2 * m) / (m + d)) / 2 := + Nat.div_le_div_right hsum + have hsub : (m + d + x / (m + d)) / 2 - m ≤ ((2 * m + (d * d + 2 * m) / (m + d)) / 2) - m := + Nat.sub_le_sub_right hhalf m + have hright : ((2 * m + (d * d + 2 * m) / (m + d)) / 2) - m + = ((d * d + 2 * m) / (m + d)) / 2 := by + let q := (d * d + 2 * m) / (m + d) + have htmp : (2 * m + q) / 2 = m + q / 2 := by + have hswap : 2 * m + q = q + m * 2 := by omega + rw [hswap, Nat.add_mul_div_right q m (by decide : 0 < 2)] + omega + rw [htmp, Nat.add_sub_cancel_left] + rw [hright] at hsub + have hden : m ≤ m + d := by omega + have hdiv2 : (d * d + 2 * m) / (m + d) ≤ (d * d + 2 * m) / m := + Nat.div_le_div_left hden hm + have hhalf2 : ((d * d + 2 * m) / (m + d)) / 2 ≤ ((d * d + 2 * m) / m) / 2 := + Nat.div_le_div_right hdiv2 + have hmain : ((d * d + 2 * m) / m) / 2 = d * d / (2 * m) + 1 := by + rw [Nat.div_div_eq_div_mul, Nat.mul_comm m 2] + have hsum2 : d * d + 2 * m = d * d + 1 * (2 * m) := by omega + rw [hsum2, Nat.add_mul_div_right (d * d) 1 (by omega : 0 < 2 * m)] + have hbound : (m + d + x / (m + d)) / 2 - m ≤ ((d * d + 2 * m) / m) / 2 := + Nat.le_trans hsub hhalf2 + exact Nat.le_trans hbound (by simp [hmain]) + +/-- Upper bound for the first post-seed error `d₁ = bstep x s - m`, using only + `m ∈ [lo, hi]` and the interval constraint `m² ≤ x < (m+1)²`. -/ +theorem d1_bound + (x m s lo hi : Nat) + (hs : 0 < s) + (hmlo : m * m ≤ x) + (hmhi : x < (m + 1) * (m + 1)) + (hlo : lo ≤ m) + (hhi : m ≤ hi) : + let maxAbs := max (s - lo) (hi - s) + bstep x s - m ≤ (maxAbs * maxAbs + 2 * hi) / (2 * s) := by + unfold bstep + simp only + have hmstep : m ≤ (s + x / s) / 2 := babylon_step_floor_bound x s m hs hmlo + have hmulsub : 2 * s * ((s + x / s) / 2 - m) = 2 * s * ((s + x / s) / 2) - 2 * s * m := by + rw [Nat.mul_sub] + have h2z : 2 * ((s + x / s) / 2) ≤ s + x / s := Nat.mul_div_le (s + x / s) 2 + have h2z_mul : 2 * s * ((s + x / s) / 2) ≤ s * (s + x / s) := by + have := Nat.mul_le_mul_left s h2z + simpa [Nat.mul_assoc, Nat.mul_comm 2 s] using this + have hsub : 2 * s * ((s + x / s) / 2 - m) ≤ s * (s + x / s) - 2 * s * m := by + have hsub' : 2 * s * ((s + x / s) / 2) - 2 * s * m ≤ s * (s + x / s) - 2 * s * m := + Nat.sub_le_sub_right h2z_mul (2 * s * m) + simpa [hmulsub] using hsub' + have hdivmul : s * (x / s) ≤ x := Nat.mul_div_le x s + have hnum1 : s * (s + x / s) - 2 * s * m ≤ s * s + x - 2 * s * m := by + have hpre : s * (s + x / s) = s * s + s * (x / s) := by rw [Nat.mul_add] + rw [hpre] + exact Nat.sub_le_sub_right (Nat.add_le_add_left hdivmul (s * s)) (2 * s * m) + have hnum2 : s * s + x - 2 * s * m ≤ s * s + m * m + 2 * m - 2 * s * m := by + have hmhi' : x < m * m + (m + m) + 1 := by + simpa [Nat.add_mul, Nat.mul_add, Nat.mul_one, Nat.one_mul, + Nat.add_assoc, Nat.add_left_comm, Nat.add_comm] using hmhi + have hx_le : x ≤ m * m + 2 * m := by omega + omega + have hnum : 2 * s * ((s + x / s) / 2 - m) ≤ s * s + m * m + 2 * m - 2 * s * m := + Nat.le_trans hsub (Nat.le_trans hnum1 hnum2) + let maxAbs := max (s - lo) (hi - s) + have hs2 : 0 < 2 * s := by omega + by_cases hsm : s ≤ m + · have hr : s * s + m * m + 2 * m - 2 * s * m = (m - s) * (m - s) + 2 * m := by + simpa [Nat.mul_assoc] using rhs_eq s m hsm + rw [hr] at hnum + have hds : m - s ≤ hi - s := Nat.sub_le_sub_right hhi s + have hsq : (m - s) * (m - s) ≤ (hi - s) * (hi - s) := Nat.mul_le_mul hds hds + have hsq' : (hi - s) * (hi - s) ≤ maxAbs * maxAbs := by + have hmmax : hi - s ≤ maxAbs := Nat.le_max_right (s - lo) (hi - s) + exact Nat.mul_le_mul hmmax hmmax + have h2m : 2 * m ≤ 2 * hi := by omega + have hfin : 2 * s * ((s + x / s) / 2 - m) ≤ maxAbs * maxAbs + 2 * hi := by + exact Nat.le_trans hnum (Nat.add_le_add (Nat.le_trans hsq hsq') h2m) + exact (Nat.le_div_iff_mul_le hs2).2 (by simpa [Nat.mul_assoc, Nat.mul_comm, Nat.mul_left_comm] using hfin) + · have hms : m ≤ s := by omega + have hr : s * s + m * m + 2 * m - 2 * s * m = (s - m) * (s - m) + 2 * m := by + simpa [Nat.mul_assoc] using rhs_eq_rev s m hms + rw [hr] at hnum + have hds : s - m ≤ s - lo := Nat.sub_le_sub_left hlo s + have hsq : (s - m) * (s - m) ≤ (s - lo) * (s - lo) := Nat.mul_le_mul hds hds + have hsq' : (s - lo) * (s - lo) ≤ maxAbs * maxAbs := by + have hmmax : s - lo ≤ maxAbs := Nat.le_max_left (s - lo) (hi - s) + exact Nat.mul_le_mul hmmax hmmax + have h2m : 2 * m ≤ 2 * hi := by omega + have hfin : 2 * s * ((s + x / s) / 2 - m) ≤ maxAbs * maxAbs + 2 * hi := by + exact Nat.le_trans hnum (Nat.add_le_add (Nat.le_trans hsq hsq') h2m) + exact (Nat.le_div_iff_mul_le hs2).2 (by simpa [Nat.mul_assoc, Nat.mul_comm, Nat.mul_left_comm] using hfin) + +end SqrtBridge diff --git a/formal/sqrt/SqrtProof/SqrtProof/CertifiedChain.lean b/formal/sqrt/SqrtProof/SqrtProof/CertifiedChain.lean new file mode 100644 index 000000000..8c4e0d97a --- /dev/null +++ b/formal/sqrt/SqrtProof/SqrtProof/CertifiedChain.lean @@ -0,0 +1,211 @@ +import Init +import SqrtProof.FloorBound +import SqrtProof.BridgeLemmas +import SqrtProof.FiniteCert + +namespace SqrtCertified + +open SqrtBridge +open SqrtCert + +def run6From (x z : Nat) : Nat := + let z := bstep x z + let z := bstep x z + let z := bstep x z + let z := bstep x z + let z := bstep x z + let z := bstep x z + z + +theorem step_from_bound + (x m lo z D : Nat) + (hm : 0 < m) + (hloPos : 0 < lo) + (hlo : lo ≤ m) + (hxhi : x < (m + 1) * (m + 1)) + (hmz : m ≤ z) + (hzD : z - m ≤ D) + (hDle : D ≤ m) : + bstep x z - m ≤ nextD lo D := by + have hz' : m + (z - m) = z := by omega + have hdle : z - m ≤ m := Nat.le_trans hzD hDle + have hstep := SqrtBridge.step_error_bound m (z - m) x hm hdle hxhi + have hstep' : bstep x z - m ≤ (z - m) * (z - m) / (2 * m) + 1 := by + simpa only [hz'] using hstep + have hsq : (z - m) * (z - m) ≤ D * D := Nat.mul_le_mul hzD hzD + have hdiv1 : (z - m) * (z - m) / (2 * m) ≤ D * D / (2 * m) := + Nat.div_le_div_right hsq + have hden : 2 * lo ≤ 2 * m := Nat.mul_le_mul_left 2 hlo + have hdiv2 : D * D / (2 * m) ≤ D * D / (2 * lo) := + Nat.div_le_div_left hden (by omega : 0 < 2 * lo) + have hfinal : (z - m) * (z - m) / (2 * m) + 1 ≤ D * D / (2 * lo) + 1 := + Nat.add_le_add_right (Nat.le_trans hdiv1 hdiv2) 1 + exact Nat.le_trans hstep' (by simpa [nextD] using hfinal) + +theorem run5_error_bounds + (i : Fin 256) + (x m : Nat) + (hm : 0 < m) + (hmlo : m * m ≤ x) + (hmhi : x < (m + 1) * (m + 1)) + (hlo : loOf i ≤ m) + (hhi : m ≤ hiOf i) : + let z1 := bstep x (seedOf i) + let z2 := bstep x z1 + let z3 := bstep x z2 + let z4 := bstep x z3 + let z5 := bstep x z4 + z1 - m ≤ d1 i ∧ + z2 - m ≤ d2 i ∧ + z3 - m ≤ d3 i ∧ + z4 - m ≤ d4 i ∧ + z5 - m ≤ d5 i := by + let z1 := bstep x (seedOf i) + let z2 := bstep x z1 + let z3 := bstep x z2 + let z4 := bstep x z3 + let z5 := bstep x z4 + + have hs : 0 < seedOf i := by + have hpow : 0 < (2 : Nat) ^ ((i.val + 1) / 2) := Nat.pow_pos (by decide : 0 < (2 : Nat)) + simpa [seedOf, Nat.shiftLeft_eq, Nat.one_mul] using hpow + + have hmz1 : m ≤ z1 := by + dsimp [z1] + exact babylon_step_floor_bound x (seedOf i) m hs hmlo + have hz1Pos : 0 < z1 := Nat.lt_of_lt_of_le hm hmz1 + have hd1 : z1 - m ≤ d1 i := by + have h := SqrtBridge.d1_bound x m (seedOf i) (loOf i) (hiOf i) hs hmlo hmhi hlo hhi + simpa [z1, d1, maxAbs] using h + have hd1m : d1 i ≤ m := Nat.le_trans (d1_le_lo i) hlo + + have hmz2 : m ≤ z2 := by + dsimp [z2] + exact babylon_step_floor_bound x z1 m hz1Pos hmlo + have hz2Pos : 0 < z2 := Nat.lt_of_lt_of_le hm hmz2 + have hd2 : z2 - m ≤ d2 i := by + have h := step_from_bound x m (loOf i) z1 (d1 i) hm (lo_pos i) hlo hmhi hmz1 hd1 hd1m + simpa [z2, d2, nextD] using h + have hd2m : d2 i ≤ m := Nat.le_trans (d2_le_lo i) hlo + + have hmz3 : m ≤ z3 := by + dsimp [z3] + exact babylon_step_floor_bound x z2 m hz2Pos hmlo + have hz3Pos : 0 < z3 := Nat.lt_of_lt_of_le hm hmz3 + have hd3 : z3 - m ≤ d3 i := by + have h := step_from_bound x m (loOf i) z2 (d2 i) hm (lo_pos i) hlo hmhi hmz2 hd2 hd2m + simpa [z3, d3, nextD] using h + have hd3m : d3 i ≤ m := Nat.le_trans (d3_le_lo i) hlo + + have hmz4 : m ≤ z4 := by + dsimp [z4] + exact babylon_step_floor_bound x z3 m hz3Pos hmlo + have hz4Pos : 0 < z4 := Nat.lt_of_lt_of_le hm hmz4 + have hd4 : z4 - m ≤ d4 i := by + have h := step_from_bound x m (loOf i) z3 (d3 i) hm (lo_pos i) hlo hmhi hmz3 hd3 hd3m + simpa [z4, d4, nextD] using h + have hd4m : d4 i ≤ m := Nat.le_trans (d4_le_lo i) hlo + + have hmz5 : m ≤ z5 := by + dsimp [z5] + exact babylon_step_floor_bound x z4 m hz4Pos hmlo + have hd5 : z5 - m ≤ d5 i := by + have h := step_from_bound x m (loOf i) z4 (d4 i) hm (lo_pos i) hlo hmhi hmz4 hd4 hd4m + simpa [z5, d5, nextD] using h + + refine ⟨?_, ?_, ?_, ?_, ?_⟩ + · exact hd1 + · exact hd2 + · exact hd3 + · exact hd4 + · exact hd5 + +theorem run6_error_le_cert + (i : Fin 256) + (x m : Nat) + (hm : 0 < m) + (hmlo : m * m ≤ x) + (hmhi : x < (m + 1) * (m + 1)) + (hlo : loOf i ≤ m) + (hhi : m ≤ hiOf i) : + run6From x (seedOf i) - m ≤ d6 i := by + let z1 := bstep x (seedOf i) + let z2 := bstep x z1 + let z3 := bstep x z2 + let z4 := bstep x z3 + let z5 := bstep x z4 + let z6 := bstep x z5 + + have hs : 0 < seedOf i := by + have hpow : 0 < (2 : Nat) ^ ((i.val + 1) / 2) := Nat.pow_pos (by decide : 0 < (2 : Nat)) + simpa [seedOf, Nat.shiftLeft_eq, Nat.one_mul] using hpow + + have hmz1 : m ≤ z1 := by + dsimp [z1] + exact babylon_step_floor_bound x (seedOf i) m hs hmlo + have hz1Pos : 0 < z1 := Nat.lt_of_lt_of_le hm hmz1 + have hd1 : z1 - m ≤ d1 i := by + have h := SqrtBridge.d1_bound x m (seedOf i) (loOf i) (hiOf i) hs hmlo hmhi hlo hhi + simpa [z1, d1, maxAbs] using h + have hd1m : d1 i ≤ m := Nat.le_trans (d1_le_lo i) hlo + + have hmz2 : m ≤ z2 := by + dsimp [z2] + exact babylon_step_floor_bound x z1 m hz1Pos hmlo + have hz2Pos : 0 < z2 := Nat.lt_of_lt_of_le hm hmz2 + have hd2 : z2 - m ≤ d2 i := by + have h := step_from_bound x m (loOf i) z1 (d1 i) hm (lo_pos i) hlo hmhi hmz1 hd1 hd1m + simpa [z2, d2, nextD] using h + have hd2m : d2 i ≤ m := Nat.le_trans (d2_le_lo i) hlo + + have hmz3 : m ≤ z3 := by + dsimp [z3] + exact babylon_step_floor_bound x z2 m hz2Pos hmlo + have hz3Pos : 0 < z3 := Nat.lt_of_lt_of_le hm hmz3 + have hd3 : z3 - m ≤ d3 i := by + have h := step_from_bound x m (loOf i) z2 (d2 i) hm (lo_pos i) hlo hmhi hmz2 hd2 hd2m + simpa [z3, d3, nextD] using h + have hd3m : d3 i ≤ m := Nat.le_trans (d3_le_lo i) hlo + + have hmz4 : m ≤ z4 := by + dsimp [z4] + exact babylon_step_floor_bound x z3 m hz3Pos hmlo + have hz4Pos : 0 < z4 := Nat.lt_of_lt_of_le hm hmz4 + have hd4 : z4 - m ≤ d4 i := by + have h := step_from_bound x m (loOf i) z3 (d3 i) hm (lo_pos i) hlo hmhi hmz3 hd3 hd3m + simpa [z4, d4, nextD] using h + have hd4m : d4 i ≤ m := Nat.le_trans (d4_le_lo i) hlo + + have hmz5 : m ≤ z5 := by + dsimp [z5] + exact babylon_step_floor_bound x z4 m hz4Pos hmlo + have hz5Pos : 0 < z5 := Nat.lt_of_lt_of_le hm hmz5 + have hd5 : z5 - m ≤ d5 i := by + have h := step_from_bound x m (loOf i) z4 (d4 i) hm (lo_pos i) hlo hmhi hmz4 hd4 hd4m + simpa [z5, d5, nextD] using h + have hd5m : d5 i ≤ m := Nat.le_trans (d5_le_lo i) hlo + + have hmz6 : m ≤ z6 := by + dsimp [z6] + exact babylon_step_floor_bound x z5 m hz5Pos hmlo + have hd6 : z6 - m ≤ d6 i := by + have h := step_from_bound x m (loOf i) z5 (d5 i) hm (lo_pos i) hlo hmhi hmz5 hd5 hd5m + simpa [z6, d6, nextD] using h + + simpa [run6From, z1, z2, z3, z4, z5, z6] using hd6 + +theorem run6_le_m_plus_one + (i : Fin 256) + (x m : Nat) + (hm : 0 < m) + (hmlo : m * m ≤ x) + (hmhi : x < (m + 1) * (m + 1)) + (hlo : loOf i ≤ m) + (hhi : m ≤ hiOf i) : + run6From x (seedOf i) ≤ m + 1 := by + have herr := run6_error_le_cert i x m hm hmlo hmhi hlo hhi + have hsub : run6From x (seedOf i) - m ≤ 1 := Nat.le_trans herr (d6_le_one i) + have hzle : run6From x (seedOf i) ≤ 1 + m := (Nat.sub_le_iff_le_add).1 hsub + omega + +end SqrtCertified diff --git a/formal/sqrt/SqrtProof/SqrtProof/FiniteCert.lean b/formal/sqrt/SqrtProof/SqrtProof/FiniteCert.lean new file mode 100644 index 000000000..755590d7b --- /dev/null +++ b/formal/sqrt/SqrtProof/SqrtProof/FiniteCert.lean @@ -0,0 +1,667 @@ +import Init + +/- + Finite certificate for sqrt upper bound, covering all 256 octaves. + + For each octave i (n = 0..255), the tables provide: + - loOf(i): lower bound on isqrt(x) for x in [2^i, 2^(i+1)-1] + - hiOf(i): upper bound on isqrt(x) + - seedOf(i): the sqrt seed for the octave = 1 <<< ((i+1)/2) + - maxAbs(i): max(|seed - lo|, |hi - seed|) + - d1(i): first-step error bound (analytic) + - nextD, d2..d6: chained error recurrence d^2/(2*lo) + 1 + + All 256 octaves verified: d6 <= 1 and dk <= lo for k=1..5. + + Auto-generated by formal/sqrt/generate_sqrt_cert.py — do not edit by hand. +-/ + +namespace SqrtCert + +set_option maxRecDepth 1000000 + +/-- Lower bounds on isqrt(x) for octaves 0..255. -/ +def loTable : Array Nat := #[ + 1, + 1, + 2, + 2, + 4, + 5, + 8, + 11, + 16, + 22, + 32, + 45, + 64, + 90, + 128, + 181, + 256, + 362, + 512, + 724, + 1024, + 1448, + 2048, + 2896, + 4096, + 5792, + 8192, + 11585, + 16384, + 23170, + 32768, + 46340, + 65536, + 92681, + 131072, + 185363, + 262144, + 370727, + 524288, + 741455, + 1048576, + 1482910, + 2097152, + 2965820, + 4194304, + 5931641, + 8388608, + 11863283, + 16777216, + 23726566, + 33554432, + 47453132, + 67108864, + 94906265, + 134217728, + 189812531, + 268435456, + 379625062, + 536870912, + 759250124, + 1073741824, + 1518500249, + 2147483648, + 3037000499, + 4294967296, + 6074000999, + 8589934592, + 12148001999, + 17179869184, + 24296003999, + 34359738368, + 48592007999, + 68719476736, + 97184015999, + 137438953472, + 194368031998, + 274877906944, + 388736063996, + 549755813888, + 777472127993, + 1099511627776, + 1554944255987, + 2199023255552, + 3109888511975, + 4398046511104, + 6219777023950, + 8796093022208, + 12439554047901, + 17592186044416, + 24879108095803, + 35184372088832, + 49758216191607, + 70368744177664, + 99516432383215, + 140737488355328, + 199032864766430, + 281474976710656, + 398065729532860, + 562949953421312, + 796131459065721, + 1125899906842624, + 1592262918131443, + 2251799813685248, + 3184525836262886, + 4503599627370496, + 6369051672525772, + 9007199254740992, + 12738103345051545, + 18014398509481984, + 25476206690103090, + 36028797018963968, + 50952413380206180, + 72057594037927936, + 101904826760412361, + 144115188075855872, + 203809653520824722, + 288230376151711744, + 407619307041649444, + 576460752303423488, + 815238614083298888, + 1152921504606846976, + 1630477228166597776, + 2305843009213693952, + 3260954456333195553, + 4611686018427387904, + 6521908912666391106, + 9223372036854775808, + 13043817825332782212, + 18446744073709551616, + 26087635650665564424, + 36893488147419103232, + 52175271301331128849, + 73786976294838206464, + 104350542602662257698, + 147573952589676412928, + 208701085205324515397, + 295147905179352825856, + 417402170410649030795, + 590295810358705651712, + 834804340821298061590, + 1180591620717411303424, + 1669608681642596123180, + 2361183241434822606848, + 3339217363285192246361, + 4722366482869645213696, + 6678434726570384492722, + 9444732965739290427392, + 13356869453140768985445, + 18889465931478580854784, + 26713738906281537970891, + 37778931862957161709568, + 53427477812563075941783, + 75557863725914323419136, + 106854955625126151883567, + 151115727451828646838272, + 213709911250252303767135, + 302231454903657293676544, + 427419822500504607534270, + 604462909807314587353088, + 854839645001009215068541, + 1208925819614629174706176, + 1709679290002018430137083, + 2417851639229258349412352, + 3419358580004036860274166, + 4835703278458516698824704, + 6838717160008073720548332, + 9671406556917033397649408, + 13677434320016147441096664, + 19342813113834066795298816, + 27354868640032294882193329, + 38685626227668133590597632, + 54709737280064589764386658, + 77371252455336267181195264, + 109419474560129179528773316, + 154742504910672534362390528, + 218838949120258359057546633, + 309485009821345068724781056, + 437677898240516718115093267, + 618970019642690137449562112, + 875355796481033436230186534, + 1237940039285380274899124224, + 1750711592962066872460373069, + 2475880078570760549798248448, + 3501423185924133744920746139, + 4951760157141521099596496896, + 7002846371848267489841492278, + 9903520314283042199192993792, + 14005692743696534979682984556, + 19807040628566084398385987584, + 28011385487393069959365969113, + 39614081257132168796771975168, + 56022770974786139918731938227, + 79228162514264337593543950336, + 112045541949572279837463876454, + 158456325028528675187087900672, + 224091083899144559674927752909, + 316912650057057350374175801344, + 448182167798289119349855505819, + 633825300114114700748351602688, + 896364335596578238699711011639, + 1267650600228229401496703205376, + 1792728671193156477399422023278, + 2535301200456458802993406410752, + 3585457342386312954798844046557, + 5070602400912917605986812821504, + 7170914684772625909597688093114, + 10141204801825835211973625643008, + 14341829369545251819195376186229, + 20282409603651670423947251286016, + 28683658739090503638390752372458, + 40564819207303340847894502572032, + 57367317478181007276781504744917, + 81129638414606681695789005144064, + 114734634956362014553563009489834, + 162259276829213363391578010288128, + 229469269912724029107126018979668, + 324518553658426726783156020576256, + 458938539825448058214252037959337, + 649037107316853453566312041152512, + 917877079650896116428504075918674, + 1298074214633706907132624082305024, + 1835754159301792232857008151837349, + 2596148429267413814265248164610048, + 3671508318603584465714016303674698, + 5192296858534827628530496329220096, + 7343016637207168931428032607349397, + 10384593717069655257060992658440192, + 14686033274414337862856065214698794, + 20769187434139310514121985316880384, + 29372066548828675725712130429397589, + 41538374868278621028243970633760768, + 58744133097657351451424260858795179, + 83076749736557242056487941267521536, + 117488266195314702902848521717590359, + 166153499473114484112975882535043072, + 234976532390629405805697043435180719, + 332306998946228968225951765070086144, + 469953064781258811611394086870361439, + 664613997892457936451903530140172288, + 939906129562517623222788173740722878, + 1329227995784915872903807060280344576, + 1879812259125035246445576347481445757, + 2658455991569831745807614120560689152, + 3759624518250070492891152694962891514, + 5316911983139663491615228241121378304, + 7519249036500140985782305389925783028, + 10633823966279326983230456482242756608, + 15038498073000281971564610779851566057, + 21267647932558653966460912964485513216, + 30076996146000563943129221559703132115, + 42535295865117307932921825928971026432, + 60153992292001127886258443119406264231, + 85070591730234615865843651857942052864, + 120307984584002255772516886238812528463, + 170141183460469231731687303715884105728, + 240615969168004511545033772477625056927 +] + +/-- Upper bounds on isqrt(x) for octaves 0..255. -/ +def hiTable : Array Nat := #[ + 1, + 1, + 2, + 3, + 5, + 7, + 11, + 15, + 22, + 31, + 45, + 63, + 90, + 127, + 181, + 255, + 362, + 511, + 724, + 1023, + 1448, + 2047, + 2896, + 4095, + 5792, + 8191, + 11585, + 16383, + 23170, + 32767, + 46340, + 65535, + 92681, + 131071, + 185363, + 262143, + 370727, + 524287, + 741455, + 1048575, + 1482910, + 2097151, + 2965820, + 4194303, + 5931641, + 8388607, + 11863283, + 16777215, + 23726566, + 33554431, + 47453132, + 67108863, + 94906265, + 134217727, + 189812531, + 268435455, + 379625062, + 536870911, + 759250124, + 1073741823, + 1518500249, + 2147483647, + 3037000499, + 4294967295, + 6074000999, + 8589934591, + 12148001999, + 17179869183, + 24296003999, + 34359738367, + 48592007999, + 68719476735, + 97184015999, + 137438953471, + 194368031998, + 274877906943, + 388736063996, + 549755813887, + 777472127993, + 1099511627775, + 1554944255987, + 2199023255551, + 3109888511975, + 4398046511103, + 6219777023950, + 8796093022207, + 12439554047901, + 17592186044415, + 24879108095803, + 35184372088831, + 49758216191607, + 70368744177663, + 99516432383215, + 140737488355327, + 199032864766430, + 281474976710655, + 398065729532860, + 562949953421311, + 796131459065721, + 1125899906842623, + 1592262918131443, + 2251799813685247, + 3184525836262886, + 4503599627370495, + 6369051672525772, + 9007199254740991, + 12738103345051545, + 18014398509481983, + 25476206690103090, + 36028797018963967, + 50952413380206180, + 72057594037927935, + 101904826760412361, + 144115188075855871, + 203809653520824722, + 288230376151711743, + 407619307041649444, + 576460752303423487, + 815238614083298888, + 1152921504606846975, + 1630477228166597776, + 2305843009213693951, + 3260954456333195553, + 4611686018427387903, + 6521908912666391106, + 9223372036854775807, + 13043817825332782212, + 18446744073709551615, + 26087635650665564424, + 36893488147419103231, + 52175271301331128849, + 73786976294838206463, + 104350542602662257698, + 147573952589676412927, + 208701085205324515397, + 295147905179352825855, + 417402170410649030795, + 590295810358705651711, + 834804340821298061590, + 1180591620717411303423, + 1669608681642596123180, + 2361183241434822606847, + 3339217363285192246361, + 4722366482869645213695, + 6678434726570384492722, + 9444732965739290427391, + 13356869453140768985445, + 18889465931478580854783, + 26713738906281537970891, + 37778931862957161709567, + 53427477812563075941783, + 75557863725914323419135, + 106854955625126151883567, + 151115727451828646838271, + 213709911250252303767135, + 302231454903657293676543, + 427419822500504607534270, + 604462909807314587353087, + 854839645001009215068541, + 1208925819614629174706175, + 1709679290002018430137083, + 2417851639229258349412351, + 3419358580004036860274166, + 4835703278458516698824703, + 6838717160008073720548332, + 9671406556917033397649407, + 13677434320016147441096664, + 19342813113834066795298815, + 27354868640032294882193329, + 38685626227668133590597631, + 54709737280064589764386658, + 77371252455336267181195263, + 109419474560129179528773316, + 154742504910672534362390527, + 218838949120258359057546633, + 309485009821345068724781055, + 437677898240516718115093267, + 618970019642690137449562111, + 875355796481033436230186534, + 1237940039285380274899124223, + 1750711592962066872460373069, + 2475880078570760549798248447, + 3501423185924133744920746139, + 4951760157141521099596496895, + 7002846371848267489841492278, + 9903520314283042199192993791, + 14005692743696534979682984556, + 19807040628566084398385987583, + 28011385487393069959365969113, + 39614081257132168796771975167, + 56022770974786139918731938227, + 79228162514264337593543950335, + 112045541949572279837463876454, + 158456325028528675187087900671, + 224091083899144559674927752909, + 316912650057057350374175801343, + 448182167798289119349855505819, + 633825300114114700748351602687, + 896364335596578238699711011639, + 1267650600228229401496703205375, + 1792728671193156477399422023278, + 2535301200456458802993406410751, + 3585457342386312954798844046557, + 5070602400912917605986812821503, + 7170914684772625909597688093114, + 10141204801825835211973625643007, + 14341829369545251819195376186229, + 20282409603651670423947251286015, + 28683658739090503638390752372458, + 40564819207303340847894502572031, + 57367317478181007276781504744917, + 81129638414606681695789005144063, + 114734634956362014553563009489834, + 162259276829213363391578010288127, + 229469269912724029107126018979668, + 324518553658426726783156020576255, + 458938539825448058214252037959337, + 649037107316853453566312041152511, + 917877079650896116428504075918674, + 1298074214633706907132624082305023, + 1835754159301792232857008151837349, + 2596148429267413814265248164610047, + 3671508318603584465714016303674698, + 5192296858534827628530496329220095, + 7343016637207168931428032607349397, + 10384593717069655257060992658440191, + 14686033274414337862856065214698794, + 20769187434139310514121985316880383, + 29372066548828675725712130429397589, + 41538374868278621028243970633760767, + 58744133097657351451424260858795179, + 83076749736557242056487941267521535, + 117488266195314702902848521717590359, + 166153499473114484112975882535043071, + 234976532390629405805697043435180719, + 332306998946228968225951765070086143, + 469953064781258811611394086870361439, + 664613997892457936451903530140172287, + 939906129562517623222788173740722878, + 1329227995784915872903807060280344575, + 1879812259125035246445576347481445757, + 2658455991569831745807614120560689151, + 3759624518250070492891152694962891514, + 5316911983139663491615228241121378303, + 7519249036500140985782305389925783028, + 10633823966279326983230456482242756607, + 15038498073000281971564610779851566057, + 21267647932558653966460912964485513215, + 30076996146000563943129221559703132115, + 42535295865117307932921825928971026431, + 60153992292001127886258443119406264231, + 85070591730234615865843651857942052863, + 120307984584002255772516886238812528463, + 170141183460469231731687303715884105727, + 240615969168004511545033772477625056927, + 340282366920938463463374607431768211455 +] + +def seedOf (i : Fin 256) : Nat := + 1 <<< ((i.val + 1) / 2) + +def loOf (i : Fin 256) : Nat := + loTable[i.val]! + +def hiOf (i : Fin 256) : Nat := + hiTable[i.val]! + +def maxAbs (i : Fin 256) : Nat := + max (seedOf i - loOf i) (hiOf i - seedOf i) + +def d1 (i : Fin 256) : Nat := + (maxAbs i * maxAbs i + 2 * hiOf i) / (2 * seedOf i) + +def nextD (lo d : Nat) : Nat := + d * d / (2 * lo) + 1 + +def d2 (i : Fin 256) : Nat := + nextD (loOf i) (d1 i) + +def d3 (i : Fin 256) : Nat := + nextD (loOf i) (d2 i) + +def d4 (i : Fin 256) : Nat := + nextD (loOf i) (d3 i) + +def d5 (i : Fin 256) : Nat := + nextD (loOf i) (d4 i) + +def d6 (i : Fin 256) : Nat := + nextD (loOf i) (d5 i) + +theorem lo_pos : ∀ i : Fin 256, 0 < loOf i := by + decide + +theorem d1_le_lo : ∀ i : Fin 256, d1 i ≤ loOf i := by + decide + +theorem d2_le_lo : ∀ i : Fin 256, d2 i ≤ loOf i := by + decide + +theorem d3_le_lo : ∀ i : Fin 256, d3 i ≤ loOf i := by + decide + +theorem d4_le_lo : ∀ i : Fin 256, d4 i ≤ loOf i := by + decide + +theorem d5_le_lo : ∀ i : Fin 256, d5 i ≤ loOf i := by + decide + +theorem d6_le_one : ∀ i : Fin 256, d6 i ≤ 1 := by + decide + +theorem lo_sq_le_pow2 : ∀ i : Fin 256, loOf i * loOf i ≤ 2 ^ i.val := by + decide + +theorem pow2_succ_le_hi_succ_sq : + ∀ i : Fin 256, 2 ^ (i.val + 1) ≤ (hiOf i + 1) * (hiOf i + 1) := by + decide + +end SqrtCert + +-- ============================================================================ +-- Sqrt512Cert: fixed-seed certificates for octaves 254/255 +-- Used by the 512-bit sqrt proof (Sqrt512Proof). +-- ============================================================================ + +namespace Sqrt512Cert + +open SqrtCert + +/-- The fixed Newton seed used by 512-bit sqrt: isqrt(2^255). + Equals hiOf(254) = loOf(255) in the finite certificate tables. -/ +def FIXED_SEED : Nat := 240615969168004511545033772477625056927 + +def lo254 : Nat := loOf ⟨254, by omega⟩ +def hi254 : Nat := hiOf ⟨254, by omega⟩ +def maxAbs254 : Nat := max (FIXED_SEED - lo254) (hi254 - FIXED_SEED) +def fd1_254 : Nat := (maxAbs254 * maxAbs254 + 2 * hi254) / (2 * FIXED_SEED) +def fd2_254 : Nat := nextD lo254 fd1_254 +def fd3_254 : Nat := nextD lo254 fd2_254 +def fd4_254 : Nat := nextD lo254 fd3_254 +def fd5_254 : Nat := nextD lo254 fd4_254 +def fd6_254 : Nat := nextD lo254 fd5_254 + +set_option maxRecDepth 100000 in +theorem fd6_254_le_one : fd6_254 ≤ 1 := by decide +set_option maxRecDepth 100000 in +theorem fd1_254_le_lo : fd1_254 ≤ lo254 := by decide +set_option maxRecDepth 100000 in +theorem fd2_254_le_lo : fd2_254 ≤ lo254 := by decide +set_option maxRecDepth 100000 in +theorem fd3_254_le_lo : fd3_254 ≤ lo254 := by decide +set_option maxRecDepth 100000 in +theorem fd4_254_le_lo : fd4_254 ≤ lo254 := by decide +set_option maxRecDepth 100000 in +theorem fd5_254_le_lo : fd5_254 ≤ lo254 := by decide +theorem lo254_pos : 0 < lo254 := lo_pos ⟨254, by omega⟩ + +def lo255 : Nat := loOf ⟨255, by omega⟩ +def hi255 : Nat := hiOf ⟨255, by omega⟩ +def maxAbs255 : Nat := max (FIXED_SEED - lo255) (hi255 - FIXED_SEED) +def fd1_255 : Nat := (maxAbs255 * maxAbs255 + 2 * hi255) / (2 * FIXED_SEED) +def fd2_255 : Nat := nextD lo255 fd1_255 +def fd3_255 : Nat := nextD lo255 fd2_255 +def fd4_255 : Nat := nextD lo255 fd3_255 +def fd5_255 : Nat := nextD lo255 fd4_255 +def fd6_255 : Nat := nextD lo255 fd5_255 + +set_option maxRecDepth 100000 in +theorem fd6_255_le_one : fd6_255 ≤ 1 := by decide +set_option maxRecDepth 100000 in +theorem fd1_255_le_lo : fd1_255 ≤ lo255 := by decide +set_option maxRecDepth 100000 in +theorem fd2_255_le_lo : fd2_255 ≤ lo255 := by decide +set_option maxRecDepth 100000 in +theorem fd3_255_le_lo : fd3_255 ≤ lo255 := by decide +set_option maxRecDepth 100000 in +theorem fd4_255_le_lo : fd4_255 ≤ lo255 := by decide +set_option maxRecDepth 100000 in +theorem fd5_255_le_lo : fd5_255 ≤ lo255 := by decide +theorem lo255_pos : 0 < lo255 := lo_pos ⟨255, by omega⟩ + +end Sqrt512Cert diff --git a/formal/sqrt/SqrtProof/SqrtProof/FiniteCertSymbolic.lean b/formal/sqrt/SqrtProof/SqrtProof/FiniteCertSymbolic.lean new file mode 100644 index 000000000..6f0303bb0 --- /dev/null +++ b/formal/sqrt/SqrtProof/SqrtProof/FiniteCertSymbolic.lean @@ -0,0 +1,623 @@ +import Init + +/-! +Legacy symbolic certificate version (kept for reference). +This file retains the original recurrence + `native_decide` style proofs. +-/ + +namespace SqrtCert + +def loTable : Array Nat := #[ + 1, + 1, + 2, + 2, + 4, + 5, + 8, + 11, + 16, + 22, + 32, + 45, + 64, + 90, + 128, + 181, + 256, + 362, + 512, + 724, + 1024, + 1448, + 2048, + 2896, + 4096, + 5792, + 8192, + 11585, + 16384, + 23170, + 32768, + 46340, + 65536, + 92681, + 131072, + 185363, + 262144, + 370727, + 524288, + 741455, + 1048576, + 1482910, + 2097152, + 2965820, + 4194304, + 5931641, + 8388608, + 11863283, + 16777216, + 23726566, + 33554432, + 47453132, + 67108864, + 94906265, + 134217728, + 189812531, + 268435456, + 379625062, + 536870912, + 759250124, + 1073741824, + 1518500249, + 2147483648, + 3037000499, + 4294967296, + 6074000999, + 8589934592, + 12148001999, + 17179869184, + 24296003999, + 34359738368, + 48592007999, + 68719476736, + 97184015999, + 137438953472, + 194368031998, + 274877906944, + 388736063996, + 549755813888, + 777472127993, + 1099511627776, + 1554944255987, + 2199023255552, + 3109888511975, + 4398046511104, + 6219777023950, + 8796093022208, + 12439554047901, + 17592186044416, + 24879108095803, + 35184372088832, + 49758216191607, + 70368744177664, + 99516432383215, + 140737488355328, + 199032864766430, + 281474976710656, + 398065729532860, + 562949953421312, + 796131459065721, + 1125899906842624, + 1592262918131443, + 2251799813685248, + 3184525836262886, + 4503599627370496, + 6369051672525772, + 9007199254740992, + 12738103345051545, + 18014398509481984, + 25476206690103090, + 36028797018963968, + 50952413380206180, + 72057594037927936, + 101904826760412361, + 144115188075855872, + 203809653520824722, + 288230376151711744, + 407619307041649444, + 576460752303423488, + 815238614083298888, + 1152921504606846976, + 1630477228166597776, + 2305843009213693952, + 3260954456333195553, + 4611686018427387904, + 6521908912666391106, + 9223372036854775808, + 13043817825332782212, + 18446744073709551616, + 26087635650665564424, + 36893488147419103232, + 52175271301331128849, + 73786976294838206464, + 104350542602662257698, + 147573952589676412928, + 208701085205324515397, + 295147905179352825856, + 417402170410649030795, + 590295810358705651712, + 834804340821298061590, + 1180591620717411303424, + 1669608681642596123180, + 2361183241434822606848, + 3339217363285192246361, + 4722366482869645213696, + 6678434726570384492722, + 9444732965739290427392, + 13356869453140768985445, + 18889465931478580854784, + 26713738906281537970891, + 37778931862957161709568, + 53427477812563075941783, + 75557863725914323419136, + 106854955625126151883567, + 151115727451828646838272, + 213709911250252303767135, + 302231454903657293676544, + 427419822500504607534270, + 604462909807314587353088, + 854839645001009215068541, + 1208925819614629174706176, + 1709679290002018430137083, + 2417851639229258349412352, + 3419358580004036860274166, + 4835703278458516698824704, + 6838717160008073720548332, + 9671406556917033397649408, + 13677434320016147441096664, + 19342813113834066795298816, + 27354868640032294882193329, + 38685626227668133590597632, + 54709737280064589764386658, + 77371252455336267181195264, + 109419474560129179528773316, + 154742504910672534362390528, + 218838949120258359057546633, + 309485009821345068724781056, + 437677898240516718115093267, + 618970019642690137449562112, + 875355796481033436230186534, + 1237940039285380274899124224, + 1750711592962066872460373069, + 2475880078570760549798248448, + 3501423185924133744920746139, + 4951760157141521099596496896, + 7002846371848267489841492278, + 9903520314283042199192993792, + 14005692743696534979682984556, + 19807040628566084398385987584, + 28011385487393069959365969113, + 39614081257132168796771975168, + 56022770974786139918731938227, + 79228162514264337593543950336, + 112045541949572279837463876454, + 158456325028528675187087900672, + 224091083899144559674927752909, + 316912650057057350374175801344, + 448182167798289119349855505819, + 633825300114114700748351602688, + 896364335596578238699711011639, + 1267650600228229401496703205376, + 1792728671193156477399422023278, + 2535301200456458802993406410752, + 3585457342386312954798844046557, + 5070602400912917605986812821504, + 7170914684772625909597688093114, + 10141204801825835211973625643008, + 14341829369545251819195376186229, + 20282409603651670423947251286016, + 28683658739090503638390752372458, + 40564819207303340847894502572032, + 57367317478181007276781504744917, + 81129638414606681695789005144064, + 114734634956362014553563009489834, + 162259276829213363391578010288128, + 229469269912724029107126018979668, + 324518553658426726783156020576256, + 458938539825448058214252037959337, + 649037107316853453566312041152512, + 917877079650896116428504075918674, + 1298074214633706907132624082305024, + 1835754159301792232857008151837349, + 2596148429267413814265248164610048, + 3671508318603584465714016303674698, + 5192296858534827628530496329220096, + 7343016637207168931428032607349397, + 10384593717069655257060992658440192, + 14686033274414337862856065214698794, + 20769187434139310514121985316880384, + 29372066548828675725712130429397589, + 41538374868278621028243970633760768, + 58744133097657351451424260858795179, + 83076749736557242056487941267521536, + 117488266195314702902848521717590359, + 166153499473114484112975882535043072, + 234976532390629405805697043435180719, + 332306998946228968225951765070086144, + 469953064781258811611394086870361439, + 664613997892457936451903530140172288, + 939906129562517623222788173740722878, + 1329227995784915872903807060280344576, + 1879812259125035246445576347481445757, + 2658455991569831745807614120560689152, + 3759624518250070492891152694962891514, + 5316911983139663491615228241121378304, + 7519249036500140985782305389925783028, + 10633823966279326983230456482242756608, + 15038498073000281971564610779851566057, + 21267647932558653966460912964485513216, + 30076996146000563943129221559703132115, + 42535295865117307932921825928971026432, + 60153992292001127886258443119406264231, + 85070591730234615865843651857942052864, + 120307984584002255772516886238812528463, + 170141183460469231731687303715884105728, + 240615969168004511545033772477625056927 +] + +def hiTable : Array Nat := #[ + 1, + 1, + 2, + 3, + 5, + 7, + 11, + 15, + 22, + 31, + 45, + 63, + 90, + 127, + 181, + 255, + 362, + 511, + 724, + 1023, + 1448, + 2047, + 2896, + 4095, + 5792, + 8191, + 11585, + 16383, + 23170, + 32767, + 46340, + 65535, + 92681, + 131071, + 185363, + 262143, + 370727, + 524287, + 741455, + 1048575, + 1482910, + 2097151, + 2965820, + 4194303, + 5931641, + 8388607, + 11863283, + 16777215, + 23726566, + 33554431, + 47453132, + 67108863, + 94906265, + 134217727, + 189812531, + 268435455, + 379625062, + 536870911, + 759250124, + 1073741823, + 1518500249, + 2147483647, + 3037000499, + 4294967295, + 6074000999, + 8589934591, + 12148001999, + 17179869183, + 24296003999, + 34359738367, + 48592007999, + 68719476735, + 97184015999, + 137438953471, + 194368031998, + 274877906943, + 388736063996, + 549755813887, + 777472127993, + 1099511627775, + 1554944255987, + 2199023255551, + 3109888511975, + 4398046511103, + 6219777023950, + 8796093022207, + 12439554047901, + 17592186044415, + 24879108095803, + 35184372088831, + 49758216191607, + 70368744177663, + 99516432383215, + 140737488355327, + 199032864766430, + 281474976710655, + 398065729532860, + 562949953421311, + 796131459065721, + 1125899906842623, + 1592262918131443, + 2251799813685247, + 3184525836262886, + 4503599627370495, + 6369051672525772, + 9007199254740991, + 12738103345051545, + 18014398509481983, + 25476206690103090, + 36028797018963967, + 50952413380206180, + 72057594037927935, + 101904826760412361, + 144115188075855871, + 203809653520824722, + 288230376151711743, + 407619307041649444, + 576460752303423487, + 815238614083298888, + 1152921504606846975, + 1630477228166597776, + 2305843009213693951, + 3260954456333195553, + 4611686018427387903, + 6521908912666391106, + 9223372036854775807, + 13043817825332782212, + 18446744073709551615, + 26087635650665564424, + 36893488147419103231, + 52175271301331128849, + 73786976294838206463, + 104350542602662257698, + 147573952589676412927, + 208701085205324515397, + 295147905179352825855, + 417402170410649030795, + 590295810358705651711, + 834804340821298061590, + 1180591620717411303423, + 1669608681642596123180, + 2361183241434822606847, + 3339217363285192246361, + 4722366482869645213695, + 6678434726570384492722, + 9444732965739290427391, + 13356869453140768985445, + 18889465931478580854783, + 26713738906281537970891, + 37778931862957161709567, + 53427477812563075941783, + 75557863725914323419135, + 106854955625126151883567, + 151115727451828646838271, + 213709911250252303767135, + 302231454903657293676543, + 427419822500504607534270, + 604462909807314587353087, + 854839645001009215068541, + 1208925819614629174706175, + 1709679290002018430137083, + 2417851639229258349412351, + 3419358580004036860274166, + 4835703278458516698824703, + 6838717160008073720548332, + 9671406556917033397649407, + 13677434320016147441096664, + 19342813113834066795298815, + 27354868640032294882193329, + 38685626227668133590597631, + 54709737280064589764386658, + 77371252455336267181195263, + 109419474560129179528773316, + 154742504910672534362390527, + 218838949120258359057546633, + 309485009821345068724781055, + 437677898240516718115093267, + 618970019642690137449562111, + 875355796481033436230186534, + 1237940039285380274899124223, + 1750711592962066872460373069, + 2475880078570760549798248447, + 3501423185924133744920746139, + 4951760157141521099596496895, + 7002846371848267489841492278, + 9903520314283042199192993791, + 14005692743696534979682984556, + 19807040628566084398385987583, + 28011385487393069959365969113, + 39614081257132168796771975167, + 56022770974786139918731938227, + 79228162514264337593543950335, + 112045541949572279837463876454, + 158456325028528675187087900671, + 224091083899144559674927752909, + 316912650057057350374175801343, + 448182167798289119349855505819, + 633825300114114700748351602687, + 896364335596578238699711011639, + 1267650600228229401496703205375, + 1792728671193156477399422023278, + 2535301200456458802993406410751, + 3585457342386312954798844046557, + 5070602400912917605986812821503, + 7170914684772625909597688093114, + 10141204801825835211973625643007, + 14341829369545251819195376186229, + 20282409603651670423947251286015, + 28683658739090503638390752372458, + 40564819207303340847894502572031, + 57367317478181007276781504744917, + 81129638414606681695789005144063, + 114734634956362014553563009489834, + 162259276829213363391578010288127, + 229469269912724029107126018979668, + 324518553658426726783156020576255, + 458938539825448058214252037959337, + 649037107316853453566312041152511, + 917877079650896116428504075918674, + 1298074214633706907132624082305023, + 1835754159301792232857008151837349, + 2596148429267413814265248164610047, + 3671508318603584465714016303674698, + 5192296858534827628530496329220095, + 7343016637207168931428032607349397, + 10384593717069655257060992658440191, + 14686033274414337862856065214698794, + 20769187434139310514121985316880383, + 29372066548828675725712130429397589, + 41538374868278621028243970633760767, + 58744133097657351451424260858795179, + 83076749736557242056487941267521535, + 117488266195314702902848521717590359, + 166153499473114484112975882535043071, + 234976532390629405805697043435180719, + 332306998946228968225951765070086143, + 469953064781258811611394086870361439, + 664613997892457936451903530140172287, + 939906129562517623222788173740722878, + 1329227995784915872903807060280344575, + 1879812259125035246445576347481445757, + 2658455991569831745807614120560689151, + 3759624518250070492891152694962891514, + 5316911983139663491615228241121378303, + 7519249036500140985782305389925783028, + 10633823966279326983230456482242756607, + 15038498073000281971564610779851566057, + 21267647932558653966460912964485513215, + 30076996146000563943129221559703132115, + 42535295865117307932921825928971026431, + 60153992292001127886258443119406264231, + 85070591730234615865843651857942052863, + 120307984584002255772516886238812528463, + 170141183460469231731687303715884105727, + 240615969168004511545033772477625056927, + 340282366920938463463374607431768211455 +] + +def seedOf (i : Fin 256) : Nat := + 1 <<< ((i.val + 1) / 2) + +def loOf (i : Fin 256) : Nat := + loTable[i.val]! + +def hiOf (i : Fin 256) : Nat := + hiTable[i.val]! + +def maxAbs (i : Fin 256) : Nat := + max (seedOf i - loOf i) (hiOf i - seedOf i) + +def d1 (i : Fin 256) : Nat := + (maxAbs i * maxAbs i + 2 * hiOf i) / (2 * seedOf i) + +def nextD (lo d : Nat) : Nat := + d * d / (2 * lo) + 1 + +def d2 (i : Fin 256) : Nat := + nextD (loOf i) (d1 i) + +def d3 (i : Fin 256) : Nat := + nextD (loOf i) (d2 i) + +def d4 (i : Fin 256) : Nat := + nextD (loOf i) (d3 i) + +def d5 (i : Fin 256) : Nat := + nextD (loOf i) (d4 i) + +def d6 (i : Fin 256) : Nat := + nextD (loOf i) (d5 i) + +def loSqHolds (i : Fin 256) : Bool := + loOf i * loOf i ≤ 2 ^ i.val + +def hiSuccSqHolds (i : Fin 256) : Bool := + 2 ^ (i.val + 1) ≤ (hiOf i + 1) * (hiOf i + 1) + +def certHolds (i : Fin 256) : Bool := + let lo := loOf i + let d1v := d1 i + let d2v := d2 i + let d3v := d3 i + let d4v := d4 i + let d5v := d5 i + let d6v := d6 i + lo > 0 && + d1v <= lo && + d2v <= lo && + d3v <= lo && + d4v <= lo && + d5v <= lo && + d6v <= 1 + +theorem all_octave_certs_pass : ∀ i : Fin 256, certHolds i = true := by + native_decide + +theorem d1_le_lo (i : Fin 256) : d1 i ≤ loOf i := by + revert i + native_decide + +theorem lo_pos (i : Fin 256) : 0 < loOf i := by + revert i + native_decide + +theorem d2_le_lo (i : Fin 256) : d2 i ≤ loOf i := by + revert i + native_decide + +theorem d3_le_lo (i : Fin 256) : d3 i ≤ loOf i := by + revert i + native_decide + +theorem d4_le_lo (i : Fin 256) : d4 i ≤ loOf i := by + revert i + native_decide + +theorem d5_le_lo (i : Fin 256) : d5 i ≤ loOf i := by + revert i + native_decide + +theorem d6_le_one (i : Fin 256) : d6 i ≤ 1 := by + revert i + native_decide + +theorem lo_sq_le_pow2 (i : Fin 256) : loOf i * loOf i ≤ 2 ^ i.val := by + revert i + native_decide + +theorem pow2_succ_le_hi_succ_sq (i : Fin 256) : + 2 ^ (i.val + 1) ≤ (hiOf i + 1) * (hiOf i + 1) := by + revert i + native_decide + +end SqrtCert diff --git a/formal/sqrt/SqrtProof/SqrtProof/FloorBound.lean b/formal/sqrt/SqrtProof/SqrtProof/FloorBound.lean new file mode 100644 index 000000000..65951c0a2 --- /dev/null +++ b/formal/sqrt/SqrtProof/SqrtProof/FloorBound.lean @@ -0,0 +1,140 @@ +/- + Lemma 1 (Floor Bound) for _sqrt convergence — Mathlib-free. + For any m with m² ≤ x, and z > 0: m ≤ (z + x / z) / 2 +-/ +import Init + +/-- One Babylonian step: ⌊(z + ⌊x/z⌋) / 2⌋. + Canonical definition used across the entire proof suite. -/ +def bstep (x z : Nat) : Nat := (z + x / z) / 2 + +-- ============================================================================ +-- Algebraic helpers +-- ============================================================================ + +/-- (a+b)² = b*(2a+b) + a² -/ +private theorem sq_decomp_1 (a b : Nat) : + (a + b) * (a + b) = b * (2 * a + b) + a * a := by + rw [Nat.add_mul, Nat.mul_add a a b, Nat.mul_add b a b] + rw [Nat.mul_add b (2 * a) b] + rw [Nat.mul_comm b (2 * a), Nat.mul_assoc 2 a b, Nat.mul_comm a b] + omega + +/-- (a+b)*(a-b) + b² = a² for b ≤ a -/ +private theorem sq_decomp_2 (a b : Nat) (h : b ≤ a) : + (a + b) * (a - b) + b * b = a * a := by + have hrecon : a - b + b = a := Nat.sub_add_cancel h + rw [Nat.mul_comm (a + b) (a - b)] + rw [Nat.mul_add (a - b) a b] + -- ((a-b)*a + (a-b)*b) + b*b = a*a + rw [Nat.add_assoc] + -- (a-b)*a + ((a-b)*b + b*b) = a*a + rw [← Nat.add_mul (a - b) b b, hrecon] + -- (a-b)*a + a*b = a*a + rw [Nat.mul_comm (a - b) a, ← Nat.mul_add a (a - b) b, hrecon] + +-- ============================================================================ +-- Core inequality: z * (2*m - z) ≤ m * m +-- ============================================================================ + +theorem sq_identity_le (z m : Nat) (h : z ≤ m) : + z * (2 * m - z) + (m - z) * (m - z) = m * m := by + have : 2 * m - z = 2 * (m - z) + z := by omega + rw [this, ← sq_decomp_1 (m - z) z, Nat.sub_add_cancel h] + +theorem sq_identity_ge (z m : Nat) (h1 : m ≤ z) (h2 : z ≤ 2 * m) : + z * (2 * m - z) + (z - m) * (z - m) = m * m := by + have key := sq_decomp_2 m (z - m) (by omega) + have h3 : m + (z - m) = z := by omega + have h4 : m - (z - m) = 2 * m - z := by omega + rw [h3, h4] at key; exact key + +theorem mul_two_sub_le_sq (z m : Nat) : z * (2 * m - z) ≤ m * m := by + by_cases h : z ≤ m + · have := sq_identity_le z m h; omega + · simp only [Nat.not_le] at h + by_cases h2 : z ≤ 2 * m + · have := sq_identity_ge z m (Nat.le_of_lt h) h2; omega + · simp only [Nat.not_le] at h2 + simp [Nat.sub_eq_zero_of_le (Nat.le_of_lt h2)] + +-- ============================================================================ +-- Division bound +-- ============================================================================ + +theorem two_mul_le_add_div_sq (m z : Nat) (hz : 0 < z) : + 2 * m ≤ z + m * m / z := by + suffices h : 2 * m - z ≤ m * m / z by omega + rw [Nat.le_div_iff_mul_le hz, Nat.mul_comm] + exact mul_two_sub_le_sq z m + +-- ============================================================================ +-- MAIN THEOREM: Lemma 1 (Floor Bound) +-- ============================================================================ + +/-- +**Lemma 1 (Floor Bound).** + +For any `m` with `m * m ≤ x`, and `z > 0`: + m ≤ (z + x / z) / 2 + +A single truncated Babylonian step never undershoots any `m` with `m² ≤ x`. +-/ +theorem babylon_step_floor_bound (x z m : Nat) (hz : 0 < z) (hm : m * m ≤ x) : + m ≤ (z + x / z) / 2 := by + rw [Nat.le_div_iff_mul_le (by omega : (0 : Nat) < 2)] + have h_mono : m * m / z ≤ x / z := Nat.div_le_div_right hm + have h_core := two_mul_le_add_div_sq m z hz + omega + +-- ============================================================================ +-- Lemma 2: Absorbing set {m, m+1} +-- ============================================================================ + +/-- (m+1)² = m² + 2m + 1 -/ +private theorem succ_sq (m : Nat) : + (m + 1) * (m + 1) = m * m + 2 * m + 1 := by + rw [sq_decomp_1 m 1, Nat.one_mul]; omega + +/-- (m-1)*(m+1) + 1 = m*m -/ +private theorem pred_succ_sq (m : Nat) (hm : 0 < m) : + (m - 1) * (m + 1) + 1 = m * m := by + -- sq_decomp_2 m 1: (m+1)*(m-1) + 1*1 = m*m + have key := sq_decomp_2 m 1 (by omega) + rw [Nat.mul_comm (m + 1) (m - 1), Nat.mul_one] at key + -- key: (m-1)*(m+1) + 1 = m*m + exact key + +/-- From z = m+1, one step gives m. -/ +theorem babylon_from_ceil (x m : Nat) (hm : 0 < m) + (hlo : m * m ≤ x) (hhi : x < (m + 1) * (m + 1)) : + (m + 1 + x / (m + 1)) / 2 = m := by + have hmp : 0 < m + 1 := by omega + -- x/(m+1) ≤ m: since x < (m+1)², x/(m+1) < m+1, so x/(m+1) ≤ m + have hd_hi : x / (m + 1) ≤ m := by + have : x / (m + 1) < m + 1 := Nat.div_lt_of_lt_mul hhi + omega + -- x/(m+1) ≥ m-1 + have hd_lo : m - 1 ≤ x / (m + 1) := by + rw [Nat.le_div_iff_mul_le hmp] + have := pred_succ_sq m hm; omega + omega + +/-- From z = m, one step gives m or m+1. -/ +theorem babylon_from_floor (x m : Nat) (hm : 0 < m) + (hlo : m * m ≤ x) (hhi : x < (m + 1) * (m + 1)) : + let z' := (m + x / m) / 2 + z' = m ∨ z' = m + 1 := by + simp only + -- x/m ≥ m + have hd_lo : m ≤ x / m := by + rw [Nat.le_div_iff_mul_le hm]; exact hlo + -- x/m ≤ m+2: x < (m+1)² = m²+2m+1, so x ≤ m²+2m = (m+2)*m + have hd_hi : x / m ≤ m + 2 := by + have hsq := succ_sq m + have hx_le : x ≤ m * m + 2 * m := by omega + calc x / m + ≤ (m * m + 2 * m) / m := Nat.div_le_div_right hx_le + _ = (m + 2) * m / m := by rw [Nat.add_mul] + _ = m + 2 := Nat.mul_div_cancel (m + 2) hm + omega diff --git a/formal/sqrt/SqrtProof/SqrtProof/GeneratedSqrtSpec.lean b/formal/sqrt/SqrtProof/SqrtProof/GeneratedSqrtSpec.lean new file mode 100644 index 000000000..b844c016a --- /dev/null +++ b/formal/sqrt/SqrtProof/SqrtProof/GeneratedSqrtSpec.lean @@ -0,0 +1,990 @@ +import Init +import SqrtProof.GeneratedSqrtModel +import SqrtProof.SqrtCorrect +import SqrtProof.CertifiedChain + +namespace SqrtGeneratedModel + +open SqrtGeneratedModel +open SqrtCertified +open SqrtCert + +private theorem normStep_eq_bstep (x z : Nat) : + normShr 1 (normAdd z (normDiv x z)) = bstep x z := by + simp [normShr, normAdd, normDiv, bstep] + +private theorem normSeed_eq_sqrtSeed_of_pos + (x : Nat) (hx : 0 < x) (hx256 : x < 2 ^ 256) : + normShl (normShr 1 (normSub 256 (normClz x))) 1 = sqrtSeed x := by + unfold normShl normShr normSub normClz sqrtSeed + simp [Nat.ne_of_gt hx] + have hlog : Nat.log2 x < 256 := (Nat.log2_lt (Nat.ne_of_gt hx)).2 hx256 + have hlogle : Nat.log2 x ≤ 255 := by omega + congr 1 + omega + +private theorem model_sqrt_zero : model_sqrt 0 = 0 := by + simp [model_sqrt, normShl, normShr, normSub, normClz, normAdd, normDiv] + +private theorem word_mod_gt_256 : 256 < WORD_MOD := by + unfold WORD_MOD + decide + +private theorem u256_eq_of_lt (x : Nat) (hx : x < WORD_MOD) : u256 x = x := by + unfold u256 + exact Nat.mod_eq_of_lt hx + +private theorem evmClz_eq_normClz_of_u256 (x : Nat) (hx : x < WORD_MOD) : + evmClz x = normClz x := by + unfold evmClz normClz + simp [u256_eq_of_lt x hx] + +private theorem normClz_le_256 (x : Nat) : normClz x ≤ 256 := by + unfold normClz + split <;> omega + +private theorem evmSub_eq_normSub_of_le + (a b : Nat) (ha : a < WORD_MOD) (hb : b ≤ a) : + evmSub a b = normSub a b := by + have hb' : b < WORD_MOD := Nat.lt_of_le_of_lt hb ha + have hab' : a - b < WORD_MOD := Nat.lt_of_le_of_lt (Nat.sub_le a b) ha + unfold evmSub normSub + simp [u256_eq_of_lt a ha, u256_eq_of_lt b hb'] + have hsplit : a + WORD_MOD - b = WORD_MOD + (a - b) := by omega + unfold u256 + rw [hsplit, Nat.add_mod, Nat.mod_eq_zero_of_dvd (Nat.dvd_refl WORD_MOD), Nat.zero_add] + simp [Nat.mod_eq_of_lt hab'] + +private theorem evmDiv_eq_normDiv_of_u256 + (x z : Nat) (hx : x < WORD_MOD) (hz : z < WORD_MOD) : + evmDiv x z = normDiv x z := by + by_cases hz0 : z = 0 + · subst hz0 + unfold evmDiv normDiv u256 + simp + · unfold evmDiv normDiv + rw [u256_eq_of_lt x hx, u256_eq_of_lt z hz] + simp [hz0] + +private theorem evmAdd_eq_normAdd_of_no_overflow + (a b : Nat) + (ha : a < WORD_MOD) + (hb : b < WORD_MOD) + (hab : a + b < WORD_MOD) : + evmAdd a b = normAdd a b := by + unfold evmAdd normAdd + simp [u256_eq_of_lt a ha, u256_eq_of_lt b hb, u256_eq_of_lt (a + b) hab] + +private theorem evmLt_eq_normLt_of_u256 + (a b : Nat) + (ha : a < WORD_MOD) + (hb : b < WORD_MOD) : + evmLt a b = normLt a b := by + unfold evmLt normLt + simp [u256_eq_of_lt a ha, u256_eq_of_lt b hb] + +private theorem evmGt_eq_normGt_of_u256 + (a b : Nat) + (ha : a < WORD_MOD) + (hb : b < WORD_MOD) : + evmGt a b = normGt a b := by + unfold evmGt normGt + simp [u256_eq_of_lt a ha, u256_eq_of_lt b hb] + +private theorem evmShr_eq_normShr_of_u256 + (s v : Nat) + (hs : s < 256) + (hv : v < WORD_MOD) : + evmShr s v = normShr s v := by + unfold evmShr normShr + have hs' : s < WORD_MOD := Nat.lt_of_lt_of_le hs (Nat.le_of_lt word_mod_gt_256) + simp [u256_eq_of_lt s hs', u256_eq_of_lt v hv, hs] + +private theorem evmShl_eq_normShl_of_safe + (s v : Nat) + (hs : s < 256) + (hv : v < WORD_MOD) + (hvs : v * 2 ^ s < WORD_MOD) : + evmShl s v = normShl s v := by + unfold evmShl normShl + have hs' : s < WORD_MOD := Nat.lt_of_lt_of_le hs (Nat.le_of_lt word_mod_gt_256) + simp [u256_eq_of_lt s hs', u256_eq_of_lt v hv, hs, Nat.shiftLeft_eq] + exact u256_eq_of_lt (v * 2 ^ s) hvs + +private theorem two_pow_lt_word (n : Nat) (hn : n < 256) : + 2 ^ n < WORD_MOD := by + unfold WORD_MOD + exact Nat.pow_lt_pow_right (by decide : 1 < (2 : Nat)) hn + +private theorem zero_lt_word : (0 : Nat) < WORD_MOD := by + unfold WORD_MOD + decide + +private theorem one_lt_word : (1 : Nat) < WORD_MOD := by + unfold WORD_MOD + decide + +private theorem pow128_plus_one_lt_word : 2 ^ 128 + 1 < WORD_MOD := by + unfold WORD_MOD + decide + +private theorem evmLt_le_one (a b : Nat) : evmLt a b ≤ 1 := by + unfold evmLt + split <;> omega + +private theorem seed_evm_eq_norm (x : Nat) (hx : x < WORD_MOD) : + evmShl (evmShr 1 (evmSub 256 (evmClz x))) 1 = + normShl (normShr 1 (normSub 256 (normClz x))) 1 := by + have hclz : evmClz x = normClz x := evmClz_eq_normClz_of_u256 x hx + have hclzLe : normClz x ≤ 256 := normClz_le_256 x + have hsub : + evmSub 256 (evmClz x) = normSub 256 (normClz x) := by + have h256 : 256 < WORD_MOD := word_mod_gt_256 + simpa [hclz] using + (evmSub_eq_normSub_of_le 256 (normClz x) h256 hclzLe) + have hsubLt : normSub 256 (normClz x) < WORD_MOD := by + have hle : normSub 256 (normClz x) ≤ 256 := by + unfold normSub + exact Nat.sub_le _ _ + exact Nat.lt_of_le_of_lt hle word_mod_gt_256 + have hshr : + evmShr 1 (evmSub 256 (evmClz x)) = + normShr 1 (normSub 256 (normClz x)) := by + have h1 : (1 : Nat) < 256 := by decide + simpa [hsub] using + (evmShr_eq_normShr_of_u256 1 (normSub 256 (normClz x)) h1 hsubLt) + have hsLt256 : normShr 1 (normSub 256 (normClz x)) < 256 := by + unfold normShr + have hle : normSub 256 (normClz x) ≤ 256 := by + unfold normSub + exact Nat.sub_le _ _ + have hdiv : normSub 256 (normClz x) / 2 ^ 1 ≤ 256 / 2 ^ 1 := Nat.div_le_div_right hle + have hdiv' : normSub 256 (normClz x) / 2 ^ 1 ≤ 128 := by simpa using hdiv + omega + have hsLtWord : normShr 1 (normSub 256 (normClz x)) < WORD_MOD := + Nat.lt_of_lt_of_le hsLt256 (Nat.le_of_lt word_mod_gt_256) + have hsafeMul : + 1 * 2 ^ (normShr 1 (normSub 256 (normClz x))) < WORD_MOD := by + simpa [Nat.one_mul] using two_pow_lt_word (normShr 1 (normSub 256 (normClz x))) hsLt256 + calc + evmShl (evmShr 1 (evmSub 256 (evmClz x))) 1 + = evmShl (normShr 1 (normSub 256 (normClz x))) 1 := by simp [hshr] + _ = normShl (normShr 1 (normSub 256 (normClz x))) 1 := by + have h1word : 1 < WORD_MOD := by + unfold WORD_MOD + decide + simpa [Nat.one_mul] using + (evmShl_eq_normShl_of_safe + (normShr 1 (normSub 256 (normClz x))) 1 hsLt256 h1word hsafeMul) + +private theorem step_evm_eq_norm_of_safe + (x z : Nat) + (hx : x < WORD_MOD) + (_hzPos : 0 < z) + (hz : z < WORD_MOD) + (hsum : z + x / z < WORD_MOD) : + evmShr 1 (evmAdd z (evmDiv x z)) = normShr 1 (normAdd z (normDiv x z)) := by + have hdivLt : x / z < WORD_MOD := Nat.lt_of_le_of_lt (Nat.div_le_self _ _) hx + have hdiv : evmDiv x z = normDiv x z := evmDiv_eq_normDiv_of_u256 x z hx hz + have hadd : evmAdd z (evmDiv x z) = normAdd z (normDiv x z) := by + simpa [hdiv] using evmAdd_eq_normAdd_of_no_overflow z (x / z) hz hdivLt hsum + have hsumLt : normAdd z (normDiv x z) < WORD_MOD := by + simpa [normAdd, normDiv] using hsum + have h1 : (1 : Nat) < 256 := by decide + calc + evmShr 1 (evmAdd z (evmDiv x z)) + = evmShr 1 (normAdd z (normDiv x z)) := by simp [hadd] + _ = normShr 1 (normAdd z (normDiv x z)) := by + simpa using evmShr_eq_normShr_of_u256 1 (normAdd z (normDiv x z)) h1 hsumLt + +private theorem m_lt_pow128_of_u256 + (m x : Nat) + (hmlo : m * m ≤ x) + (hx : x < WORD_MOD) : + m < 2 ^ 128 := by + by_cases hm128 : m < 2 ^ 128 + · exact hm128 + · have hmGe : 2 ^ 128 ≤ m := Nat.le_of_not_lt hm128 + have hmSqGe : 2 ^ 256 ≤ m * m := by + have hpow : 2 ^ 256 = (2 ^ 128) * (2 ^ 128) := by + calc + 2 ^ 256 = 2 ^ (128 + 128) := by decide + _ = (2 ^ 128) * (2 ^ 128) := by rw [Nat.pow_add] + have hmul : (2 ^ 128) * (2 ^ 128) ≤ m * m := Nat.mul_le_mul hmGe hmGe + simpa [hpow] using hmul + have hxGe : 2 ^ 256 ≤ x := Nat.le_trans hmSqGe hmlo + exact False.elim ((Nat.not_lt_of_ge hxGe) hx) + +private theorem x_div_m_le_m_plus_two + (x m : Nat) + (hm : 0 < m) + (hmhi : x < (m + 1) * (m + 1)) : + x / m ≤ m + 2 := by + have hmhi' : x < m * m + 2 * m + 1 := by + have hsq : (m + 1) * (m + 1) = m * m + 2 * m + 1 := by + rw [Nat.add_mul, Nat.mul_add, Nat.mul_one, Nat.one_mul] + omega + simpa [hsq] using hmhi + have hmhi'' : x < (m * m + 2 * m) + 1 := by omega + have hx_le : x ≤ m * m + 2 * m := Nat.lt_succ_iff.mp hmhi'' + calc + x / m ≤ (m * m + 2 * m) / m := Nat.div_le_div_right hx_le + _ = (m + 2) * m / m := by rw [Nat.add_mul] + _ = m + 2 := Nat.mul_div_cancel (m + 2) hm + +private theorem sum_lt_word_of_cert + (x m z d : Nat) + (hx : x < WORD_MOD) + (hm : 0 < m) + (hmlo : m * m ≤ x) + (hmhi : x < (m + 1) * (m + 1)) + (hmz : m ≤ z) + (hzd : z - m ≤ d) + (hdm : d ≤ m) : + z + x / z < WORD_MOD := by + have hdiv_z_m : x / z ≤ x / m := Nat.div_le_div_left hmz hm + have hdiv_m : x / m ≤ m + 2 := x_div_m_le_m_plus_two x m hm hmhi + have hdiv : x / z ≤ m + 2 := Nat.le_trans hdiv_z_m hdiv_m + have hz_le_md : z ≤ d + m := (Nat.sub_le_iff_le_add).1 hzd + have hz_le_2m : z ≤ 2 * m := by omega + have hsum_le : z + x / z ≤ 3 * m + 2 := by omega + have hm128 : m < 2 ^ 128 := m_lt_pow128_of_u256 m x hmlo hx + have hsum_lt_const : z + x / z < 3 * (2 ^ 128) + 2 := by omega + have hconst : 3 * (2 ^ 128) + 2 < WORD_MOD := by + unfold WORD_MOD + decide + exact Nat.lt_trans hsum_lt_const hconst + +private theorem seed_sum_lt_word + (i : Fin 256) (x : Nat) + (hOct : 2 ^ i.val ≤ x ∧ x < 2 ^ (i.val + 1)) : + seedOf i + x / seedOf i < WORD_MOD := by + have hsPos : 0 < seedOf i := by + have hpow : 0 < (2 : Nat) ^ ((i.val + 1) / 2) := Nat.pow_pos (by decide : 0 < (2 : Nat)) + simpa [seedOf, Nat.shiftLeft_eq, Nat.one_mul] using hpow + have hk_le : (i.val + 1) / 2 ≤ 128 := by omega + have hz_le : seedOf i ≤ 2 ^ 128 := by + unfold seedOf + rw [Nat.shiftLeft_eq, Nat.one_mul] + exact Nat.pow_le_pow_right (by decide : (2 : Nat) > 0) hk_le + have hExp : i.val + 1 ≤ 2 * ((i.val + 1) / 2) + 1 := by omega + have hPowLe : 2 ^ (i.val + 1) ≤ 2 ^ (2 * ((i.val + 1) / 2) + 1) := + Nat.pow_le_pow_right (by decide : (2 : Nat) > 0) hExp + have hPowMul : 2 ^ (2 * ((i.val + 1) / 2) + 1) = 2 * seedOf i * seedOf i := by + calc + 2 ^ (2 * ((i.val + 1) / 2) + 1) = 2 ^ (2 * ((i.val + 1) / 2)) * 2 := by rw [Nat.pow_add] + _ = (2 ^ ((i.val + 1) / 2) * 2 ^ ((i.val + 1) / 2)) * 2 := by + rw [show 2 * ((i.val + 1) / 2) = ((i.val + 1) / 2) + ((i.val + 1) / 2) by omega, Nat.pow_add] + _ = 2 * seedOf i * seedOf i := by + unfold seedOf + simp [Nat.shiftLeft_eq, Nat.one_mul, Nat.mul_comm, Nat.mul_left_comm] + have hxmul : x < 2 * seedOf i * seedOf i := by + exact Nat.lt_of_lt_of_le hOct.2 (by simpa [hPowMul] using hPowLe) + have hdiv : x / seedOf i < 2 * seedOf i := by + apply (Nat.div_lt_iff_lt_mul hsPos).2 + simpa [Nat.mul_assoc, Nat.mul_comm, Nat.mul_left_comm] using hxmul + have hsum_lt : seedOf i + x / seedOf i < seedOf i + 2 * seedOf i := by omega + have hsum_le : seedOf i + 2 * seedOf i ≤ 3 * (2 ^ 128) := by omega + have hconst : 3 * (2 ^ 128) < WORD_MOD := by + unfold WORD_MOD + decide + exact Nat.lt_of_lt_of_le (Nat.lt_of_lt_of_le hsum_lt hsum_le) (Nat.le_of_lt hconst) + +private theorem normLt_div_le (x z : Nat) : + normLt (normDiv x z) z ≤ z := by + by_cases hz0 : z = 0 + · simp [normLt, normDiv, hz0] + · have hzPos : 0 < z := Nat.pos_of_ne_zero hz0 + have h1 : 1 ≤ z := Nat.succ_le_of_lt hzPos + by_cases hlt : x / z < z + · simp [normLt, normDiv, hlt, h1] + · simp [normLt, normDiv, hlt] + +private theorem floor_correction_norm_eq_if (x z : Nat) : + normSub z (normLt (normDiv x z) z) = + (if z = 0 then 0 else if x / z < z then z - 1 else z) := by + by_cases hz0 : z = 0 + · subst hz0 + simp [normSub, normLt, normDiv] + · by_cases hlt : x / z < z + · simp [normSub, normLt, normDiv, hz0, hlt] + · simp [normSub, normLt, normDiv, hz0, hlt] + +theorem model_sqrt_evm_eq_model_sqrt + (x : Nat) + (hx256 : x < WORD_MOD) : + model_sqrt_evm x = model_sqrt x := by + by_cases hx0 : x = 0 + · subst hx0 + decide + · have hx : 0 < x := Nat.pos_of_ne_zero hx0 + let i : Fin 256 := ⟨Nat.log2 x, (Nat.log2_lt (Nat.ne_of_gt hx)).2 hx256⟩ + let m := natSqrt x + have hmlo : m * m ≤ x := by simpa [m] using natSqrt_sq_le x + have hmhi : x < (m + 1) * (m + 1) := by simpa [m] using natSqrt_lt_succ_sq x + have hOct : 2 ^ i.val ≤ x ∧ x < 2 ^ (i.val + 1) := by + have hlog : 2 ^ Nat.log2 x ≤ x ∧ x < 2 ^ (Nat.log2 x + 1) := + (Nat.log2_eq_iff (Nat.ne_of_gt hx)).1 rfl + simpa [i] + have hm : 0 < m := by + by_cases hm0 : m = 0 + · + have hx1 : 1 ≤ x := Nat.succ_le_of_lt hx + have hlt1 : x < 1 := by + have : x < (0 + 1) * (0 + 1) := by simpa [hm0] using hmhi + simpa using this + exact False.elim ((Nat.not_lt_of_ge hx1) hlt1) + · exact Nat.pos_of_ne_zero hm0 + have hseedOf : sqrtSeed x = seedOf i := sqrtSeed_eq_seedOf_of_octave i x hOct + have hseedNorm : + normShl (normShr 1 (normSub 256 (normClz x))) 1 = seedOf i := by + exact (normSeed_eq_sqrtSeed_of_pos x hx hx256).trans hseedOf + have hseedEvm : + evmShl (evmShr 1 (evmSub 256 (evmClz x))) 1 = seedOf i := by + exact (seed_evm_eq_norm x hx256).trans hseedNorm + let z0 := seedOf i + let z1 := bstep x z0 + let z2 := bstep x z1 + let z3 := bstep x z2 + let z4 := bstep x z3 + let z5 := bstep x z4 + let z6 := bstep x z5 + have hsPos : 0 < z0 := by + dsimp [z0] + have hpow : 0 < (2 : Nat) ^ ((i.val + 1) / 2) := Nat.pow_pos (by decide : 0 < (2 : Nat)) + simpa [seedOf, Nat.shiftLeft_eq, Nat.one_mul] using hpow + have hmz1 : m ≤ z1 := by + dsimp [z1, z0] + exact babylon_step_floor_bound x (seedOf i) m hsPos hmlo + have hz1Pos : 0 < z1 := Nat.lt_of_lt_of_le hm hmz1 + have hmz2 : m ≤ z2 := by + dsimp [z2] + exact babylon_step_floor_bound x z1 m hz1Pos hmlo + have hz2Pos : 0 < z2 := Nat.lt_of_lt_of_le hm hmz2 + have hmz3 : m ≤ z3 := by + dsimp [z3] + exact babylon_step_floor_bound x z2 m hz2Pos hmlo + have hz3Pos : 0 < z3 := Nat.lt_of_lt_of_le hm hmz3 + have hmz4 : m ≤ z4 := by + dsimp [z4] + exact babylon_step_floor_bound x z3 m hz3Pos hmlo + have hz4Pos : 0 < z4 := Nat.lt_of_lt_of_le hm hmz4 + have hmz5 : m ≤ z5 := by + dsimp [z5] + exact babylon_step_floor_bound x z4 m hz4Pos hmlo + have hz5Pos : 0 < z5 := Nat.lt_of_lt_of_le hm hmz5 + have hinterval : loOf i ≤ m ∧ m ≤ hiOf i := m_within_cert_interval i x m hmlo hmhi hOct + have hrun5 := run5_error_bounds i x m hm hmlo hmhi hinterval.1 hinterval.2 + have hd1 : z1 - m ≤ d1 i := by simpa [z1, z2, z3, z4, z5] using hrun5.1 + have hd2 : z2 - m ≤ d2 i := by simpa [z1, z2, z3, z4, z5] using hrun5.2.1 + have hd3 : z3 - m ≤ d3 i := by simpa [z1, z2, z3, z4, z5] using hrun5.2.2.1 + have hd4 : z4 - m ≤ d4 i := by simpa [z1, z2, z3, z4, z5] using hrun5.2.2.2.1 + have hd5 : z5 - m ≤ d5 i := by simpa [z1, z2, z3, z4, z5] using hrun5.2.2.2.2 + have hd1m : d1 i ≤ m := Nat.le_trans (d1_le_lo i) hinterval.1 + have hd2m : d2 i ≤ m := Nat.le_trans (d2_le_lo i) hinterval.1 + have hd3m : d3 i ≤ m := Nat.le_trans (d3_le_lo i) hinterval.1 + have hd4m : d4 i ≤ m := Nat.le_trans (d4_le_lo i) hinterval.1 + have hd5m : d5 i ≤ m := Nat.le_trans (d5_le_lo i) hinterval.1 + have hsum0 : z0 + x / z0 < WORD_MOD := by + simpa [z0] using seed_sum_lt_word i x hOct + have hsum1 : z1 + x / z1 < WORD_MOD := sum_lt_word_of_cert x m z1 (d1 i) hx256 hm hmlo hmhi hmz1 hd1 hd1m + have hsum2 : z2 + x / z2 < WORD_MOD := sum_lt_word_of_cert x m z2 (d2 i) hx256 hm hmlo hmhi hmz2 hd2 hd2m + have hsum3 : z3 + x / z3 < WORD_MOD := sum_lt_word_of_cert x m z3 (d3 i) hx256 hm hmlo hmhi hmz3 hd3 hd3m + have hsum4 : z4 + x / z4 < WORD_MOD := sum_lt_word_of_cert x m z4 (d4 i) hx256 hm hmlo hmhi hmz4 hd4 hd4m + have hsum5 : z5 + x / z5 < WORD_MOD := sum_lt_word_of_cert x m z5 (d5 i) hx256 hm hmlo hmhi hmz5 hd5 hd5m + have hz0 : z0 < WORD_MOD := Nat.lt_of_le_of_lt (Nat.le_add_right z0 (x / z0)) hsum0 + have hz1 : z1 < WORD_MOD := Nat.lt_of_le_of_lt (Nat.le_add_right z1 (x / z1)) hsum1 + have hz2 : z2 < WORD_MOD := Nat.lt_of_le_of_lt (Nat.le_add_right z2 (x / z2)) hsum2 + have hz3 : z3 < WORD_MOD := Nat.lt_of_le_of_lt (Nat.le_add_right z3 (x / z3)) hsum3 + have hz4 : z4 < WORD_MOD := Nat.lt_of_le_of_lt (Nat.le_add_right z4 (x / z4)) hsum4 + have hz5 : z5 < WORD_MOD := Nat.lt_of_le_of_lt (Nat.le_add_right z5 (x / z5)) hsum5 + have hstep1 : evmShr 1 (evmAdd z0 (evmDiv x z0)) = z1 := by + have h := step_evm_eq_norm_of_safe x z0 hx256 hsPos hz0 hsum0 + simpa [z1, normStep_eq_bstep] using h + have hstep2 : evmShr 1 (evmAdd z1 (evmDiv x z1)) = z2 := by + have h := step_evm_eq_norm_of_safe x z1 hx256 hz1Pos hz1 hsum1 + simpa [z2, normStep_eq_bstep] using h + have hstep3 : evmShr 1 (evmAdd z2 (evmDiv x z2)) = z3 := by + have h := step_evm_eq_norm_of_safe x z2 hx256 hz2Pos hz2 hsum2 + simpa [z3, normStep_eq_bstep] using h + have hstep4 : evmShr 1 (evmAdd z3 (evmDiv x z3)) = z4 := by + have h := step_evm_eq_norm_of_safe x z3 hx256 hz3Pos hz3 hsum3 + simpa [z4, normStep_eq_bstep] using h + have hstep5 : evmShr 1 (evmAdd z4 (evmDiv x z4)) = z5 := by + have h := step_evm_eq_norm_of_safe x z4 hx256 hz4Pos hz4 hsum4 + simpa [z5, normStep_eq_bstep] using h + have hstep6 : evmShr 1 (evmAdd z5 (evmDiv x z5)) = z6 := by + have h := step_evm_eq_norm_of_safe x z5 hx256 hz5Pos hz5 hsum5 + simpa [z6, normStep_eq_bstep] using h + have hxmod : u256 x = x := u256_eq_of_lt x hx256 + unfold model_sqrt_evm model_sqrt + simp [hxmod, hseedEvm, hseedNorm, z0, z1, z2, z3, z4, z5, z6, + hstep1, hstep2, hstep3, hstep4, hstep5, hstep6, normStep_eq_bstep] + +theorem model_sqrt_eq_innerSqrt (x : Nat) (hx256 : x < 2 ^ 256) : + model_sqrt x = innerSqrt x := by + by_cases hx0 : x = 0 + · subst hx0 + simp [innerSqrt, model_sqrt_zero] + · have hx : 0 < x := Nat.pos_of_ne_zero hx0 + have hseed : normShl (normShr 1 (normSub 256 (normClz x))) 1 = sqrtSeed x := + normSeed_eq_sqrtSeed_of_pos x hx hx256 + unfold model_sqrt innerSqrt + simp [Nat.ne_of_gt hx, hseed, normStep_eq_bstep] + +theorem model_sqrt_bracket_u256_all + (x : Nat) + (hx256 : x < 2 ^ 256) : + let m := natSqrt x + m ≤ model_sqrt x ∧ model_sqrt x ≤ m + 1 := by + simpa [model_sqrt_eq_innerSqrt x hx256] using innerSqrt_bracket_u256_all x hx256 + +theorem model_sqrt_evm_bracket_u256_all + (x : Nat) + (hx256 : x < 2 ^ 256) : + let m := natSqrt x + m ≤ model_sqrt_evm x ∧ model_sqrt_evm x ≤ m + 1 := by + have hxW : x < WORD_MOD := by simpa [WORD_MOD] using hx256 + simpa [model_sqrt_evm_eq_model_sqrt x hxW] using model_sqrt_bracket_u256_all x hx256 + +theorem model_sqrt_floor_eq_floorSqrt + (x : Nat) + (hx256 : x < 2 ^ 256) : + model_sqrt_floor x = floorSqrt x := by + have hinner : model_sqrt x = innerSqrt x := model_sqrt_eq_innerSqrt x hx256 + unfold model_sqrt_floor floorSqrt + simp [hinner, floor_correction_norm_eq_if] + +private theorem floor_step_evm_eq_norm + (x z : Nat) + (hx : x < WORD_MOD) + (hz : z < WORD_MOD) : + evmSub z (evmLt (evmDiv x z) z) = + normSub z (normLt (normDiv x z) z) := by + have hdiv : evmDiv x z = normDiv x z := evmDiv_eq_normDiv_of_u256 x z hx hz + have hdivLt : normDiv x z < WORD_MOD := Nat.lt_of_le_of_lt (Nat.div_le_self _ _) hx + have hlt : evmLt (evmDiv x z) z = normLt (normDiv x z) z := by + simpa [hdiv] using evmLt_eq_normLt_of_u256 (normDiv x z) z hdivLt hz + have hbLe : normLt (normDiv x z) z ≤ z := normLt_div_le x z + calc + evmSub z (evmLt (evmDiv x z) z) + = evmSub z (normLt (normDiv x z) z) := by simp [hlt] + _ = normSub z (normLt (normDiv x z) z) := + evmSub_eq_normSub_of_le z (normLt (normDiv x z) z) hz hbLe + +theorem model_sqrt_floor_evm_eq_model_sqrt_floor + (x : Nat) + (hxW : x < WORD_MOD) : + model_sqrt_floor_evm x = model_sqrt_floor x := by + have hx256 : x < 2 ^ 256 := by simpa [WORD_MOD] using hxW + have hbr := model_sqrt_evm_bracket_u256_all x hx256 + have hzLe : model_sqrt_evm x ≤ natSqrt x + 1 := by simpa using hbr.2 + have hm128 : natSqrt x < 2 ^ 128 := + m_lt_pow128_of_u256 (natSqrt x) x (natSqrt_sq_le x) hxW + have hz128 : model_sqrt_evm x ≤ 2 ^ 128 := by omega + have hpow128 : 2 ^ 128 < WORD_MOD := two_pow_lt_word 128 (by decide) + have hzW : model_sqrt_evm x < WORD_MOD := Nat.lt_of_le_of_lt hz128 hpow128 + have hroot : model_sqrt_evm x = model_sqrt x := model_sqrt_evm_eq_model_sqrt x hxW + have hxmod : u256 x = x := u256_eq_of_lt x hxW + unfold model_sqrt_floor_evm model_sqrt_floor + simp [hxmod] + simpa [hroot] using floor_step_evm_eq_norm x (model_sqrt_evm x) hxW hzW + +theorem model_sqrt_floor_evm_eq_floorSqrt + (x : Nat) + (hx256 : x < 2 ^ 256) : + model_sqrt_floor_evm x = floorSqrt x := by + have hxW : x < WORD_MOD := by simpa [WORD_MOD] using hx256 + calc + model_sqrt_floor_evm x = model_sqrt_floor x := model_sqrt_floor_evm_eq_model_sqrt_floor x hxW + _ = floorSqrt x := model_sqrt_floor_eq_floorSqrt x hx256 + +/-- Specification-level model for `sqrtUp`: round `innerSqrt` upward if needed. -/ +def sqrtUpSpec (x : Nat) : Nat := + let z := innerSqrt x + if z * z < x then z + 1 else z + +private theorem model_sqrt_up_norm_eq_sqrtUpSpec + (x : Nat) + (hx256 : x < 2 ^ 256) : + model_sqrt_up x = sqrtUpSpec x := by + have hinner : model_sqrt x = innerSqrt x := model_sqrt_eq_innerSqrt x hx256 + have hsqge : innerSqrt x ≤ innerSqrt x * innerSqrt x := by + by_cases hz0 : innerSqrt x = 0 + · simp [hz0] + · have hzPos : 0 < innerSqrt x := Nat.pos_of_ne_zero hz0 + have h1 : 1 ≤ innerSqrt x := Nat.succ_le_of_lt hzPos + calc + innerSqrt x = innerSqrt x * 1 := by simp + _ ≤ innerSqrt x * innerSqrt x := Nat.mul_le_mul_left _ h1 + unfold model_sqrt_up sqrtUpSpec + by_cases hlt : innerSqrt x * innerSqrt x < x + · simp [normAdd, normMul, normLt, normGt, hinner, hlt, hsqge, Nat.add_comm] + · simp [normAdd, normMul, normLt, normGt, hinner, hlt] + +private theorem sqrtUp_step_evm_eq_spec + (x z : Nat) + (hxW : x < WORD_MOD) + (hzLe128 : z ≤ 2 ^ 128) : + evmAdd (evmGt (evmLt (evmMul z z) x) (evmLt (evmMul z z) z)) z = + (if z * z < x then z + 1 else z) := by + have hpow128 : 2 ^ 128 < WORD_MOD := two_pow_lt_word 128 (by decide) + have hzW : z < WORD_MOD := Nat.lt_of_le_of_lt hzLe128 hpow128 + by_cases hzMax : z = 2 ^ 128 + · have hsqEq : z * z = WORD_MOD := by + rw [hzMax] + unfold WORD_MOD + calc + (2 ^ 128) * (2 ^ 128) = 2 ^ (128 + 128) := by rw [← Nat.pow_add] + _ = 2 ^ 256 := by decide + have hmul0 : evmMul z z = 0 := by + unfold evmMul u256 + simp [hsqEq] + have hzPos : 0 < z := by + rw [hzMax] + exact Nat.two_pow_pos 128 + have hltZ1 : evmLt (evmMul z z) z = 1 := by + rw [hmul0] + have hltEq : evmLt 0 z = normLt 0 z := evmLt_eq_normLt_of_u256 0 z zero_lt_word hzW + have hnorm1 : normLt 0 z = 1 := by + unfold normLt + simp [hzPos] + exact hltEq.trans hnorm1 + have hltXLe : evmLt (evmMul z z) x ≤ 1 := evmLt_le_one (evmMul z z) x + have hltXW : evmLt (evmMul z z) x < WORD_MOD := Nat.lt_of_le_of_lt hltXLe one_lt_word + have hgt0 : evmGt (evmLt (evmMul z z) x) (evmLt (evmMul z z) z) = 0 := by + rw [hltZ1] + have hgtEq : + evmGt (evmLt (evmMul z z) x) 1 = normGt (evmLt (evmMul z z) x) 1 := + evmGt_eq_normGt_of_u256 (evmLt (evmMul z z) x) 1 hltXW one_lt_word + have hnorm0 : normGt (evmLt (evmMul z z) x) 1 = 0 := by + unfold normGt + have hnot : ¬ evmLt (evmMul z z) x > 1 := Nat.not_lt_of_ge hltXLe + simp [hnot] + exact hgtEq.trans hnorm0 + have hadd0 : evmAdd 0 z = z := by + have h := evmAdd_eq_normAdd_of_no_overflow 0 z zero_lt_word hzW (by simpa using hzW) + simpa [normAdd] using h + have hsqNotLt : ¬ z * z < x := by + rw [hsqEq] + exact Nat.not_lt_of_ge (Nat.le_of_lt hxW) + calc + evmAdd (evmGt (evmLt (evmMul z z) x) (evmLt (evmMul z z) z)) z + = evmAdd 0 z := by simp [hgt0] + _ = z := hadd0 + _ = if z * z < x then z + 1 else z := by simp [hsqNotLt] + · have hzLt : z < 2 ^ 128 := Nat.lt_of_le_of_ne hzLe128 hzMax + have hzzW : z * z < WORD_MOD := by + have hmulLe : z * z ≤ z * (2 ^ 128) := Nat.mul_le_mul_left z (Nat.le_of_lt hzLt) + have hmulLt : z * (2 ^ 128) < (2 ^ 128) * (2 ^ 128) := + Nat.mul_lt_mul_of_pos_right hzLt (Nat.two_pow_pos 128) + have hlt : z * z < (2 ^ 128) * (2 ^ 128) := Nat.lt_of_le_of_lt hmulLe hmulLt + have hpowEq : (2 ^ 128) * (2 ^ 128) = WORD_MOD := by + unfold WORD_MOD + calc + (2 ^ 128) * (2 ^ 128) = 2 ^ (128 + 128) := by rw [← Nat.pow_add] + _ = 2 ^ 256 := by decide + simpa [hpowEq] using hlt + have hmulNat : evmMul z z = z * z := by + unfold evmMul + simp [u256_eq_of_lt z hzW, u256_eq_of_lt (z * z) hzzW] + have hsqGe : z ≤ z * z := by + by_cases hz0 : z = 0 + · simp [hz0] + · have hzPos : 0 < z := Nat.pos_of_ne_zero hz0 + have h1 : 1 ≤ z := Nat.succ_le_of_lt hzPos + calc + z = z * 1 := by simp + _ ≤ z * z := Nat.mul_le_mul_left z h1 + have hltZ0 : evmLt (evmMul z z) z = 0 := by + rw [hmulNat] + unfold evmLt + have hnot : ¬ z * z < z := Nat.not_lt_of_ge hsqGe + simp [u256_eq_of_lt (z * z) hzzW, u256_eq_of_lt z hzW, hnot] + by_cases hsqx : z * z < x + · have hltX1 : evmLt (evmMul z z) x = 1 := by + rw [hmulNat] + unfold evmLt + simp [u256_eq_of_lt (z * z) hzzW, u256_eq_of_lt x hxW, hsqx] + have hgt1 : evmGt (evmLt (evmMul z z) x) (evmLt (evmMul z z) z) = 1 := by + rw [hltX1, hltZ0] + have hgtEq : evmGt 1 0 = normGt 1 0 := + evmGt_eq_normGt_of_u256 1 0 one_lt_word zero_lt_word + have hnorm1 : normGt 1 0 = 1 := by + unfold normGt + decide + exact hgtEq.trans hnorm1 + have hsum1 : 1 + z < WORD_MOD := by + have hle : 1 + z ≤ 1 + 2 ^ 128 := by omega + exact Nat.lt_of_le_of_lt hle pow128_plus_one_lt_word + have hadd1 : evmAdd 1 z = z + 1 := by + have h := evmAdd_eq_normAdd_of_no_overflow 1 z one_lt_word hzW hsum1 + simpa [normAdd, Nat.add_comm] using h + calc + evmAdd (evmGt (evmLt (evmMul z z) x) (evmLt (evmMul z z) z)) z + = evmAdd 1 z := by simp [hgt1] + _ = z + 1 := hadd1 + _ = if z * z < x then z + 1 else z := by simp [hsqx] + · have hltX0 : evmLt (evmMul z z) x = 0 := by + rw [hmulNat] + unfold evmLt + simp [u256_eq_of_lt (z * z) hzzW, u256_eq_of_lt x hxW, hsqx] + have hgt0 : evmGt (evmLt (evmMul z z) x) (evmLt (evmMul z z) z) = 0 := by + rw [hltX0, hltZ0] + unfold evmGt + simp + have hadd0 : evmAdd 0 z = z := by + have h := evmAdd_eq_normAdd_of_no_overflow 0 z zero_lt_word hzW (by simpa using hzW) + simpa [normAdd] using h + calc + evmAdd (evmGt (evmLt (evmMul z z) x) (evmLt (evmMul z z) z)) z + = evmAdd 0 z := by simp [hgt0] + _ = z := hadd0 + _ = if z * z < x then z + 1 else z := by simp [hsqx] + +theorem model_sqrt_up_eq_sqrtUpSpec + (x : Nat) + (hx256 : x < 2 ^ 256) : + model_sqrt_up x = sqrtUpSpec x := + model_sqrt_up_norm_eq_sqrtUpSpec x hx256 + +theorem model_sqrt_up_evm_eq_sqrtUpSpec + (x : Nat) + (hx256 : x < 2 ^ 256) : + model_sqrt_up_evm x = sqrtUpSpec x := by + have hxW : x < WORD_MOD := by simpa [WORD_MOD] using hx256 + have hbr := model_sqrt_evm_bracket_u256_all x hx256 + have hzLe : model_sqrt_evm x ≤ natSqrt x + 1 := by simpa using hbr.2 + have hm128 : natSqrt x < 2 ^ 128 := + m_lt_pow128_of_u256 (natSqrt x) x (natSqrt_sq_le x) hxW + have hzLe128 : model_sqrt_evm x ≤ 2 ^ 128 := by omega + have hroot : model_sqrt_evm x = innerSqrt x := by + exact (model_sqrt_evm_eq_model_sqrt x hxW).trans (model_sqrt_eq_innerSqrt x hx256) + have hxmod : u256 x = x := u256_eq_of_lt x hxW + unfold model_sqrt_up_evm sqrtUpSpec + simp [hxmod] + simpa [hroot] using sqrtUp_step_evm_eq_spec x (model_sqrt_evm x) hxW hzLe128 + +private theorem step_error_bound_square + (m d : Nat) + (hm : 0 < m) + (hmd : d ≤ m) : + bstep (m * m) (m + d) - m ≤ d * d / (2 * m) := by + unfold bstep + have hpos : 0 < m + d := by omega + have hsq : m * m = (m + d) * (m - d) + d * d := by + have h := sq_identity_ge (m + d) m (by omega) (by omega) + have hsub : 2 * m - (m + d) = m - d := by omega + have hdm' : (m + d) - m = d := by rw [Nat.add_sub_cancel_left] + simpa [hsub, hdm'] using h.symm + have hdiv : m * m / (m + d) = (m - d) + d * d / (m + d) := by + rw [hsq] + rw [Nat.mul_add_div hpos] + have hrewrite : + (m + d + m * m / (m + d)) / 2 - m = (d * d / (m + d)) / 2 := by + rw [hdiv] + let q := d * d / (m + d) + have htmp : (m + d + (m - d + q)) / 2 = m + q / 2 := by + have hsum : m + d + (m - d + q) = 2 * m + q := by omega + rw [hsum] + have htmp2 : (2 * m + q) / 2 = m + q / 2 := by + have hswap : 2 * m + q = q + m * 2 := by omega + rw [hswap, Nat.add_mul_div_right q m (by decide : 0 < 2)] + omega + exact htmp2 + rw [htmp, Nat.add_sub_cancel_left] + rw [hrewrite] + have hden : m ≤ m + d := by omega + have hdivLe : d * d / (m + d) ≤ d * d / m := Nat.div_le_div_left hden hm + have hhalf : (d * d / (m + d)) / 2 ≤ (d * d / m) / 2 := Nat.div_le_div_right hdivLe + have hmain : (d * d / m) / 2 = d * d / (2 * m) := by + rw [Nat.div_div_eq_div_mul, Nat.mul_comm m 2] + exact Nat.le_trans hhalf (by simp [hmain]) + +private theorem step_from_bound_square + (m lo z D : Nat) + (hm : 0 < m) + (hloPos : 0 < lo) + (hlo : lo ≤ m) + (hmz : m ≤ z) + (hzD : z - m ≤ D) + (hDlo : D ≤ lo) : + bstep (m * m) z - m ≤ D * D / (2 * lo) := by + let d := z - m + have hdEq : z = m + d := by + dsimp [d] + omega + have hdm : d ≤ m := by + dsimp [d] + omega + have hstep : bstep (m * m) (m + d) - m ≤ d * d / (2 * m) := + step_error_bound_square m d hm hdm + have hbase : bstep (m * m) z - m ≤ d * d / (2 * m) := by + simpa [hdEq] using hstep + have hdD : d ≤ D := by + simpa [d] using hzD + have hsq : d * d ≤ D * D := Nat.mul_le_mul hdD hdD + have hdiv : d * d / (2 * m) ≤ D * D / (2 * m) := Nat.div_le_div_right hsq + have hden : 2 * lo ≤ 2 * m := Nat.mul_le_mul_left 2 hlo + have hdivDen : D * D / (2 * m) ≤ D * D / (2 * lo) := + Nat.div_le_div_left hden (by omega : 0 < 2 * lo) + exact Nat.le_trans hbase (Nat.le_trans hdiv hdivDen) + +private def sqNext (lo d : Nat) : Nat := d * d / (2 * lo) + +private def sqD2 (i : Fin 256) : Nat := sqNext (loOf i) (d1 i) +private def sqD3 (i : Fin 256) : Nat := sqNext (loOf i) (sqD2 i) +private def sqD4 (i : Fin 256) : Nat := sqNext (loOf i) (sqD3 i) +private def sqD5 (i : Fin 256) : Nat := sqNext (loOf i) (sqD4 i) +private def sqD6 (i : Fin 256) : Nat := sqNext (loOf i) (sqD5 i) + +private theorem sqNext_mono_right (lo a b : Nat) (hab : a ≤ b) : + sqNext lo a ≤ sqNext lo b := by + unfold sqNext + exact Nat.div_le_div_right (Nat.mul_le_mul hab hab) + +private theorem sqNext_le_lo + (lo d : Nat) + (hlo : 0 < lo) + (hd : d ≤ lo) : + sqNext lo d ≤ lo := by + unfold sqNext + have hsq : d * d ≤ lo * lo := Nat.mul_le_mul hd hd + have hdiv : d * d / (2 * lo) ≤ lo * lo / (2 * lo) := Nat.div_le_div_right hsq + have hden : lo ≤ 2 * lo := by omega + have hdiv' : lo * lo / (2 * lo) ≤ lo * lo / lo := Nat.div_le_div_left hden hlo + have hmul : lo * lo / lo = lo := by simpa [Nat.mul_comm] using Nat.mul_div_right lo hlo + exact Nat.le_trans hdiv (by simpa [hmul] using hdiv') + +private theorem sqD2_le_lo : ∀ i : Fin 256, sqD2 i ≤ loOf i := by + intro i + unfold sqD2 + exact sqNext_le_lo (loOf i) (d1 i) (lo_pos i) (d1_le_lo i) + +private theorem sqD3_le_lo : ∀ i : Fin 256, sqD3 i ≤ loOf i := by + intro i + unfold sqD3 + exact sqNext_le_lo (loOf i) (sqD2 i) (lo_pos i) (sqD2_le_lo i) + +private theorem sqD4_le_lo : ∀ i : Fin 256, sqD4 i ≤ loOf i := by + intro i + unfold sqD4 + exact sqNext_le_lo (loOf i) (sqD3 i) (lo_pos i) (sqD3_le_lo i) + +private theorem sqD5_le_lo : ∀ i : Fin 256, sqD5 i ≤ loOf i := by + intro i + unfold sqD5 + exact sqNext_le_lo (loOf i) (sqD4 i) (lo_pos i) (sqD4_le_lo i) + +private theorem sqD2_le_d2 : ∀ i : Fin 256, sqD2 i ≤ d2 i := by + intro i + simp [sqD2, d2, sqNext, nextD] + +private theorem sqD3_le_d3 : ∀ i : Fin 256, sqD3 i ≤ d3 i := by + intro i + have hmono : sqNext (loOf i) (sqD2 i) ≤ sqNext (loOf i) (d2 i) := + sqNext_mono_right (loOf i) (sqD2 i) (d2 i) (sqD2_le_d2 i) + unfold sqD3 d3 nextD + exact Nat.le_trans hmono (Nat.le_succ _) + +private theorem sqD4_le_d4 : ∀ i : Fin 256, sqD4 i ≤ d4 i := by + intro i + have hmono : sqNext (loOf i) (sqD3 i) ≤ sqNext (loOf i) (d3 i) := + sqNext_mono_right (loOf i) (sqD3 i) (d3 i) (sqD3_le_d3 i) + unfold sqD4 d4 nextD + exact Nat.le_trans hmono (Nat.le_succ _) + +private theorem sqD5_le_d5 : ∀ i : Fin 256, sqD5 i ≤ d5 i := by + intro i + have hmono : sqNext (loOf i) (sqD4 i) ≤ sqNext (loOf i) (d4 i) := + sqNext_mono_right (loOf i) (sqD4 i) (d4 i) (sqD4_le_d4 i) + unfold sqD5 d5 nextD + exact Nat.le_trans hmono (Nat.le_succ _) + +private theorem sqD6_eq_zero : ∀ i : Fin 256, sqD6 i = 0 := by + intro i + have hsqLe : sqD6 i ≤ sqNext (loOf i) (d5 i) := by + unfold sqD6 + exact sqNext_mono_right (loOf i) (sqD5 i) (d5 i) (sqD5_le_d5 i) + have hd6le : d6 i ≤ 1 := d6_le_one i + have hd6ge : 1 ≤ d6 i := by + unfold d6 nextD + exact Nat.succ_le_succ (Nat.zero_le _) + have hd6eq : d6 i = 1 := Nat.le_antisymm hd6le hd6ge + have hsq0 : sqNext (loOf i) (d5 i) = 0 := by + have hq : sqNext (loOf i) (d5 i) + 1 = d6 i := by + simp [sqNext, d6, nextD] + omega + have hsqD6le0 : sqD6 i ≤ 0 := Nat.le_trans hsqLe (by simp [hsq0]) + exact Nat.eq_zero_of_le_zero hsqD6le0 + +private theorem innerSqrt_eq_natSqrt_of_square + (x : Nat) + (hx256 : x < 2 ^ 256) + (hsq : natSqrt x * natSqrt x = x) : + innerSqrt x = natSqrt x := by + by_cases hx0 : x = 0 + · subst hx0 + simp [innerSqrt, natSqrt] + · have hx : 0 < x := Nat.pos_of_ne_zero hx0 + let m := natSqrt x + have hmSq : m * m = x := by simpa [m] using hsq + have hmlo : m * m ≤ x := by simp [m, hmSq] + have hmhi : x < (m + 1) * (m + 1) := by simpa [m] using natSqrt_lt_succ_sq x + have hm : 0 < m := by + by_cases hm0 : m = 0 + · have hx0' : x = 0 := by simpa [m, hm0] using hmSq.symm + exact False.elim (hx0 hx0') + · exact Nat.pos_of_ne_zero hm0 + let i : Fin 256 := ⟨Nat.log2 x, (Nat.log2_lt (Nat.ne_of_gt hx)).2 (by simpa [WORD_MOD] using hx256)⟩ + have hOct : 2 ^ i.val ≤ x ∧ x < 2 ^ (i.val + 1) := by + have hlog : 2 ^ Nat.log2 x ≤ x ∧ x < 2 ^ (Nat.log2 x + 1) := + (Nat.log2_eq_iff (Nat.ne_of_gt hx)).1 rfl + simpa [i] + have hseed : sqrtSeed x = seedOf i := sqrtSeed_eq_seedOf_of_octave i x hOct + let z0 := seedOf i + let z1 := bstep x z0 + let z2 := bstep x z1 + let z3 := bstep x z2 + let z4 := bstep x z3 + let z5 := bstep x z4 + let z6 := bstep x z5 + have hsPos : 0 < z0 := by + dsimp [z0] + have hpow : 0 < (2 : Nat) ^ ((i.val + 1) / 2) := Nat.pow_pos (by decide : 0 < (2 : Nat)) + simpa [seedOf, Nat.shiftLeft_eq, Nat.one_mul] using hpow + have hmz1 : m ≤ z1 := by + dsimp [z1, z0] + exact babylon_step_floor_bound x (seedOf i) m hsPos hmlo + have hz1Pos : 0 < z1 := Nat.lt_of_lt_of_le hm hmz1 + have hmz2 : m ≤ z2 := by + dsimp [z2] + exact babylon_step_floor_bound x z1 m hz1Pos hmlo + have hz2Pos : 0 < z2 := Nat.lt_of_lt_of_le hm hmz2 + have hmz3 : m ≤ z3 := by + dsimp [z3] + exact babylon_step_floor_bound x z2 m hz2Pos hmlo + have hz3Pos : 0 < z3 := Nat.lt_of_lt_of_le hm hmz3 + have hmz4 : m ≤ z4 := by + dsimp [z4] + exact babylon_step_floor_bound x z3 m hz3Pos hmlo + have hz4Pos : 0 < z4 := Nat.lt_of_lt_of_le hm hmz4 + have hmz5 : m ≤ z5 := by + dsimp [z5] + exact babylon_step_floor_bound x z4 m hz4Pos hmlo + have hz5Pos : 0 < z5 := Nat.lt_of_lt_of_le hm hmz5 + have hmz6 : m ≤ z6 := by + dsimp [z6] + exact babylon_step_floor_bound x z5 m hz5Pos hmlo + have hinterval : loOf i ≤ m ∧ m ≤ hiOf i := m_within_cert_interval i x m hmlo hmhi hOct + have hrun5 := run5_error_bounds i x m hm hmlo hmhi hinterval.1 hinterval.2 + have hd1 : z1 - m ≤ d1 i := by simpa [z1, z2, z3, z4, z5] using hrun5.1 + have hd2 : z2 - m ≤ sqD2 i := by + have h := step_from_bound_square m (loOf i) z1 (d1 i) hm (lo_pos i) hinterval.1 hmz1 hd1 (d1_le_lo i) + simpa [z2, hmSq, sqD2, sqNext] using h + have hd3 : z3 - m ≤ sqD3 i := by + have h := step_from_bound_square m (loOf i) z2 (sqD2 i) hm (lo_pos i) hinterval.1 hmz2 hd2 (sqD2_le_lo i) + simpa [z3, hmSq, sqD3, sqNext] using h + have hd4 : z4 - m ≤ sqD4 i := by + have h := step_from_bound_square m (loOf i) z3 (sqD3 i) hm (lo_pos i) hinterval.1 hmz3 hd3 (sqD3_le_lo i) + simpa [z4, hmSq, sqD4, sqNext] using h + have hd5 : z5 - m ≤ sqD5 i := by + have h := step_from_bound_square m (loOf i) z4 (sqD4 i) hm (lo_pos i) hinterval.1 hmz4 hd4 (sqD4_le_lo i) + simpa [z5, hmSq, sqD5, sqNext] using h + have hd6 : z6 - m ≤ sqD6 i := by + have h := step_from_bound_square m (loOf i) z5 (sqD5 i) hm (lo_pos i) hinterval.1 hmz5 hd5 (sqD5_le_lo i) + simpa [z6, hmSq, sqD6, sqNext] using h + have hz6le : z6 ≤ m := by + have h0 : z6 - m = 0 := by + have h0le : z6 - m ≤ 0 := by simpa [sqD6_eq_zero i] using hd6 + exact Nat.eq_zero_of_le_zero h0le + exact (Nat.sub_eq_zero_iff_le).1 h0 + have hz6eq : z6 = m := Nat.le_antisymm hz6le hmz6 + have hrun : innerSqrt x = run6From x (seedOf i) := by + calc + innerSqrt x = run6From x (sqrtSeed x) := innerSqrt_eq_run6From x hx + _ = run6From x (seedOf i) := by simp [hseed] + have hrun6 : run6From x (seedOf i) = z6 := by + unfold run6From + simp [z1, z2, z3, z4, z5, z6, z0, bstep] + calc + innerSqrt x = run6From x (seedOf i) := hrun + _ = z6 := hrun6 + _ = m := hz6eq + _ = natSqrt x := by rfl + +private theorem minimal_of_pred_lt + (x r : Nat) + (hpred : r = 0 ∨ (r - 1) * (r - 1) < x) : + ∀ y, x ≤ y * y → r ≤ y := by + intro y hy + by_cases hry : r ≤ y + · exact hry + · have hylt : y < r := Nat.lt_of_not_ge hry + cases hpred with + | inl hr0 => + exact False.elim ((Nat.not_lt_of_ge hylt) (by simp [hr0])) + | inr hpredlt => + have hyle : y ≤ r - 1 := by omega + have hysq : y * y ≤ (r - 1) * (r - 1) := Nat.mul_le_mul hyle hyle + have hcontra : x ≤ (r - 1) * (r - 1) := Nat.le_trans hy hysq + exact False.elim ((Nat.not_lt_of_ge hcontra) hpredlt) + +theorem model_sqrt_up_evm_ceil_u256 + (x : Nat) + (hx256 : x < 2 ^ 256) : + let r := model_sqrt_up_evm x + x ≤ r * r ∧ ∀ y, x ≤ y * y → r ≤ y := by + have hUp : model_sqrt_up_evm x = sqrtUpSpec x := model_sqrt_up_evm_eq_sqrtUpSpec x hx256 + rw [hUp] + unfold sqrtUpSpec + let m := natSqrt x + have hmlo : m * m ≤ x := by simpa [m] using natSqrt_sq_le x + have hmhi : x < (m + 1) * (m + 1) := by simpa [m] using natSqrt_lt_succ_sq x + have hbr : m ≤ innerSqrt x ∧ innerSqrt x ≤ m + 1 := by + simpa [m] using innerSqrt_bracket_u256_all x hx256 + by_cases hlt : innerSqrt x * innerSqrt x < x + · have hinter : innerSqrt x = m := by + have hneq : innerSqrt x ≠ m + 1 := by + intro hce + have hbad : (m + 1) * (m + 1) < x := by simpa [hce] using hlt + exact False.elim ((Nat.not_lt_of_ge (Nat.le_of_lt hmhi)) hbad) + omega + simp [hlt] + constructor + · have hupper : x ≤ (m + 1) * (m + 1) := Nat.le_of_lt hmhi + simpa [hinter] + · exact minimal_of_pred_lt x (innerSqrt x + 1) (Or.inr (by simpa using hlt)) + · simp [hlt] + constructor + · exact Nat.le_of_not_gt hlt + · have hpred : + innerSqrt x = 0 ∨ (innerSqrt x - 1) * (innerSqrt x - 1) < x := by + have hsqCases : m * m < x ∨ m * m = x := Nat.lt_or_eq_of_le hmlo + cases hsqCases with + | inl hsqLt => + have hle : innerSqrt x - 1 ≤ m := by omega + have hsqle : (innerSqrt x - 1) * (innerSqrt x - 1) ≤ m * m := Nat.mul_le_mul hle hle + right + exact Nat.lt_of_le_of_lt hsqle hsqLt + | inr hsqEq => + have hinnerEq : innerSqrt x = m := by + have hsqNat : natSqrt x * natSqrt x = x := by simpa [m] using hsqEq + simpa [m] using innerSqrt_eq_natSqrt_of_square x hx256 hsqNat + by_cases hm0 : m = 0 + · left + simp [hinnerEq, hm0] + · right + have hmPos : 0 < m := Nat.pos_of_ne_zero hm0 + have hpredm : (m - 1) * (m - 1) < x := by + have hsubLt : m - 1 < m := by + simpa [Nat.pred_eq_sub_one] using Nat.pred_lt (Nat.ne_of_gt hmPos) + have hle : (m - 1) * (m - 1) ≤ (m - 1) * m := + Nat.mul_le_mul_left (m - 1) (Nat.sub_le _ _) + have hlt : (m - 1) * m < m * m := Nat.mul_lt_mul_of_pos_right hsubLt hmPos + have hltm : (m - 1) * (m - 1) < m * m := Nat.lt_of_le_of_lt hle hlt + simpa [hsqEq] using hltm + simpa [hinnerEq] using hpredm + exact minimal_of_pred_lt x (innerSqrt x) hpred + +end SqrtGeneratedModel diff --git a/formal/sqrt/SqrtProof/SqrtProof/SqrtCorrect.lean b/formal/sqrt/SqrtProof/SqrtProof/SqrtCorrect.lean new file mode 100644 index 000000000..52c35e99d --- /dev/null +++ b/formal/sqrt/SqrtProof/SqrtProof/SqrtCorrect.lean @@ -0,0 +1,432 @@ +/- + Correctness components for Sqrt.sol:_sqrt and sqrt. + + Theorem 1 (innerSqrt_correct): + Lower-bound component: if m² ≤ x then m ≤ innerSqrt(x) (for x > 0). + + Theorem 2 (floorSqrt_correct): + Given a 1-ULP bracket for innerSqrt(x), floorSqrt(x) satisfies + r² ≤ x < (r+1)². +-/ +import Init +import SqrtProof.FloorBound +import SqrtProof.StepMono +import SqrtProof.CertifiedChain + +-- ============================================================================ +-- Part 1: Definitions matching Sqrt.sol EVM semantics +-- ============================================================================ + +/-- The seed: z₀ = 2^⌊(log2(x)+1)/2⌋. For x=0, returns 0. + Matches EVM: shl(shr(1, sub(256, clz(x))), 1) + Since 256 - clz(x) = bitLength(x) = log2(x) + 1 for x > 0. -/ +def sqrtSeed (x : Nat) : Nat := + if x = 0 then 0 + else 1 <<< ((Nat.log2 x + 1) / 2) + +/-- _sqrt: seed + 6 Babylonian steps. Returns z ∈ {isqrt(x), isqrt(x)+1}. -/ +def innerSqrt (x : Nat) : Nat := + if x = 0 then 0 + else + let z := sqrtSeed x + let z := bstep x z + let z := bstep x z + let z := bstep x z + let z := bstep x z + let z := bstep x z + let z := bstep x z + z + +/-- sqrt: _sqrt with floor correction. Returns exactly isqrt(x). + Matches: z := sub(z, lt(div(x, z), z)) -/ +def floorSqrt (x : Nat) : Nat := + let z := innerSqrt x + if z = 0 then 0 + else if x / z < z then z - 1 else z + +-- ============================================================================ +-- Part 2: Lower bound (composing Lemma 1) +-- ============================================================================ + +/-- The seed is positive for x > 0. -/ +theorem sqrtSeed_pos (x : Nat) (hx : 0 < x) : + 0 < sqrtSeed x := by + unfold sqrtSeed + simp [Nat.ne_of_gt hx] + rw [Nat.shiftLeft_eq, Nat.one_mul] + exact Nat.lt_of_lt_of_le (by omega : 0 < 1) (Nat.one_le_pow _ 2 (by omega)) + +/-- bstep preserves positivity when x > 0 and z > 0. -/ +theorem bstep_pos (x z : Nat) (hx : 0 < x) (hz : 0 < z) : 0 < bstep x z := by + unfold bstep + -- For x ≥ 1 and z ≥ 1: z + x/z ≥ 2 (since z ≥ 1 and x/z ≥ 1 when z = 1, + -- or z ≥ 2 when x/z = 0). So (z + x/z)/2 ≥ 1. + by_cases hle : x < z + · -- x < z, so x/z = 0. But z ≥ 2 (since x ≥ 1 and x < z means z ≥ 2). + have : x / z = 0 := Nat.div_eq_zero_iff.mpr (Or.inr hle) + omega + · -- x ≥ z, so x/z ≥ 1 + have : 0 < x / z := Nat.div_pos (by omega) hz + omega + +/-- Canonical integer-square-root witness built by simple recursion. + This avoids additional dependencies while giving `m² ≤ x < (m+1)²`. -/ +def natSqrt : Nat → Nat + | 0 => 0 + | n + 1 => + let m := natSqrt n + if (m + 1) * (m + 1) ≤ n + 1 then m + 1 else m + +/-- Correctness spec for `natSqrt`. -/ +theorem natSqrt_spec (n : Nat) : + natSqrt n * natSqrt n ≤ n ∧ n < (natSqrt n + 1) * (natSqrt n + 1) := by + induction n with + | zero => + simp [natSqrt] + | succ n ih => + rcases ih with ⟨ihle, ihlt⟩ + let m := natSqrt n + have ihle' : m * m ≤ n := by simpa [m] using ihle + have ihlt' : n < (m + 1) * (m + 1) := by simpa [m] using ihlt + by_cases hstep : (m + 1) * (m + 1) ≤ n + 1 + · have hn1eq : n + 1 = (m + 1) * (m + 1) := by omega + have hm12 : m + 1 < m + 2 := by omega + have hleft : (m + 1) * (m + 1) < (m + 1) * (m + 2) := + Nat.mul_lt_mul_of_pos_left hm12 (by omega : 0 < m + 1) + have hright : (m + 2) * (m + 1) < (m + 2) * (m + 2) := + Nat.mul_lt_mul_of_pos_left hm12 (by omega : 0 < m + 2) + have hsq_lt : (m + 1) * (m + 1) < (m + 2) * (m + 2) := by + calc + (m + 1) * (m + 1) < (m + 1) * (m + 2) := hleft + _ = (m + 2) * (m + 1) := by rw [Nat.mul_comm] + _ < (m + 2) * (m + 2) := hright + constructor + · simp [natSqrt, m, hstep] + · have hlt2 : n + 1 < (m + 2) * (m + 2) := by simpa [hn1eq] using hsq_lt + simpa [natSqrt, m, hstep, Nat.add_assoc, Nat.add_left_comm, Nat.add_comm] using hlt2 + · have hn1lt : n + 1 < (m + 1) * (m + 1) := Nat.lt_of_not_ge hstep + constructor + · have hmle : m * m ≤ n + 1 := Nat.le_trans ihle' (Nat.le_succ n) + simpa [natSqrt, m, hstep] using hmle + · simpa [natSqrt, m, hstep] using hn1lt + +theorem natSqrt_sq_le (n : Nat) : natSqrt n * natSqrt n ≤ n := + (natSqrt_spec n).1 + +theorem natSqrt_lt_succ_sq (n : Nat) : n < (natSqrt n + 1) * (natSqrt n + 1) := + (natSqrt_spec n).2 + +-- ============================================================================ +-- Part 4: Main theorems +-- ============================================================================ + +/-- innerSqrt gives a lower bound: for any m with m² ≤ x, m ≤ innerSqrt(x). + This follows from 6 applications of babylon_step_floor_bound. -/ +theorem innerSqrt_lower (x m : Nat) (hx : 0 < x) + (hm : m * m ≤ x) : m ≤ innerSqrt x := by + unfold innerSqrt + simp [Nat.ne_of_gt hx] + -- The seed is positive + have hs := sqrtSeed_pos x hx + -- Each bstep preserves positivity (x > 0) + -- Chain: m ≤ bstep x (bstep x (... (bstep x (sqrtSeed x)))) + -- Each step: if m² ≤ x and z > 0, then m ≤ bstep x z + -- bstep is defined in FloorBound + -- babylon_step_floor_bound : m*m ≤ x → 0 < z → m ≤ (z + x/z)/2 + have h1 := bstep_pos x _ hx hs + have h2 := bstep_pos x _ hx h1 + have h3 := bstep_pos x _ hx h2 + have h4 := bstep_pos x _ hx h3 + have h5 := bstep_pos x _ hx h4 + -- Apply floor bound at the last step (z₅ is positive by h5) + exact babylon_step_floor_bound x _ m h5 hm + +/-- Unfolding identity: `innerSqrt` is six steps starting from `sqrtSeed`. -/ +theorem innerSqrt_eq_run6From (x : Nat) (hx : 0 < x) : + innerSqrt x = SqrtCertified.run6From x (sqrtSeed x) := by + unfold innerSqrt SqrtCertified.run6From + simp [Nat.ne_of_gt hx, bstep] + +/-- Finite-certificate upper bound: if `m` is bracketed by the octave certificate, + then six steps from the actual seed satisfy `innerSqrt x ≤ m + 1`. -/ +theorem innerSqrt_upper_cert + (i : Fin 256) (x m : Nat) + (hx : 0 < x) + (hm : 0 < m) + (hmlo : m * m ≤ x) + (hmhi : x < (m + 1) * (m + 1)) + (hseed : sqrtSeed x = SqrtCert.seedOf i) + (hlo : SqrtCert.loOf i ≤ m) + (hhi : m ≤ SqrtCert.hiOf i) : + innerSqrt x ≤ m + 1 := by + have hrun : SqrtCertified.run6From x (SqrtCert.seedOf i) ≤ m + 1 := + SqrtCertified.run6_le_m_plus_one i x m hm hmlo hmhi hlo hhi + calc + innerSqrt x = SqrtCertified.run6From x (sqrtSeed x) := innerSqrt_eq_run6From x hx + _ = SqrtCertified.run6From x (SqrtCert.seedOf i) := by simp [hseed] + _ ≤ m + 1 := hrun + +/-- Certificate-backed 1-ULP bracket for `innerSqrt`. -/ +theorem innerSqrt_bracket_cert + (i : Fin 256) (x m : Nat) + (hx : 0 < x) + (hm : 0 < m) + (hmlo : m * m ≤ x) + (hmhi : x < (m + 1) * (m + 1)) + (hseed : sqrtSeed x = SqrtCert.seedOf i) + (hlo : SqrtCert.loOf i ≤ m) + (hhi : m ≤ SqrtCert.hiOf i) : + m ≤ innerSqrt x ∧ innerSqrt x ≤ m + 1 := by + exact ⟨innerSqrt_lower x m hx hmlo, innerSqrt_upper_cert i x m hx hm hmlo hmhi hseed hlo hhi⟩ + +/-- `sqrtSeed` agrees with the finite-certificate seed on octave `i`. -/ +theorem sqrtSeed_eq_seedOf_of_octave + (i : Fin 256) (x : Nat) + (hOct : 2 ^ i.val ≤ x ∧ x < 2 ^ (i.val + 1)) : + sqrtSeed x = SqrtCert.seedOf i := by + have hx : 0 < x := Nat.lt_of_lt_of_le (Nat.two_pow_pos i.val) hOct.1 + have hx0 : x ≠ 0 := Nat.ne_of_gt hx + have hlog : Nat.log2 x = i.val := (Nat.log2_eq_iff hx0).2 hOct + unfold sqrtSeed SqrtCert.seedOf + simp [Nat.ne_of_gt hx, hlog] + +/-- From the certified octave endpoints and `m² ≤ x < (m+1)²`, + derive `m ∈ [loOf i, hiOf i]`. -/ +theorem m_within_cert_interval + (i : Fin 256) (x m : Nat) + (hmlo : m * m ≤ x) + (hmhi : x < (m + 1) * (m + 1)) + (hOct : 2 ^ i.val ≤ x ∧ x < 2 ^ (i.val + 1)) : + SqrtCert.loOf i ≤ m ∧ m ≤ SqrtCert.hiOf i := by + have hloSq : SqrtCert.loOf i * SqrtCert.loOf i ≤ 2 ^ i.val := SqrtCert.lo_sq_le_pow2 i + have hloSqX : SqrtCert.loOf i * SqrtCert.loOf i ≤ x := Nat.le_trans hloSq hOct.1 + have hlo : SqrtCert.loOf i ≤ m := by + by_cases h : SqrtCert.loOf i ≤ m + · exact h + · have hlt : m < SqrtCert.loOf i := Nat.lt_of_not_ge h + have hm1 : m + 1 ≤ SqrtCert.loOf i := Nat.succ_le_of_lt hlt + have hm1sq : (m + 1) * (m + 1) ≤ SqrtCert.loOf i * SqrtCert.loOf i := + Nat.mul_le_mul hm1 hm1 + have hm1x : (m + 1) * (m + 1) ≤ x := Nat.le_trans hm1sq hloSqX + exact False.elim ((Nat.not_lt_of_ge hm1x) hmhi) + have hhiSq : 2 ^ (i.val + 1) ≤ (SqrtCert.hiOf i + 1) * (SqrtCert.hiOf i + 1) := + SqrtCert.pow2_succ_le_hi_succ_sq i + have hXHi : x < (SqrtCert.hiOf i + 1) * (SqrtCert.hiOf i + 1) := + Nat.lt_of_lt_of_le hOct.2 hhiSq + have hhi : m ≤ SqrtCert.hiOf i := by + by_cases h : m ≤ SqrtCert.hiOf i + · exact h + · have hlt : SqrtCert.hiOf i < m := Nat.lt_of_not_ge h + have hhi1 : SqrtCert.hiOf i + 1 ≤ m := Nat.succ_le_of_lt hlt + have hhimsq : (SqrtCert.hiOf i + 1) * (SqrtCert.hiOf i + 1) ≤ m * m := + Nat.mul_le_mul hhi1 hhi1 + have hXmm : x < m * m := Nat.lt_of_lt_of_le hXHi hhimsq + exact False.elim ((Nat.not_lt_of_ge hmlo) hXmm) + exact ⟨hlo, hhi⟩ + +/-- Certificate-backed upper bound under octave membership. -/ +theorem innerSqrt_upper_of_octave + (i : Fin 256) (x m : Nat) + (hmlo : m * m ≤ x) + (hmhi : x < (m + 1) * (m + 1)) + (hOct : 2 ^ i.val ≤ x ∧ x < 2 ^ (i.val + 1)) : + innerSqrt x ≤ m + 1 := by + have hx : 0 < x := Nat.lt_of_lt_of_le (Nat.two_pow_pos i.val) hOct.1 + have hm : 0 < m := by + by_cases hm0 : m = 0 + · subst hm0 + have hx1 : 1 ≤ x := Nat.succ_le_of_lt hx + have hlt1 : x < 1 := by simpa using hmhi + exact False.elim ((Nat.not_lt_of_ge hx1) hlt1) + · exact Nat.pos_of_ne_zero hm0 + have hseed : sqrtSeed x = SqrtCert.seedOf i := sqrtSeed_eq_seedOf_of_octave i x hOct + have hinterval : SqrtCert.loOf i ≤ m ∧ m ≤ SqrtCert.hiOf i := + m_within_cert_interval i x m hmlo hmhi hOct + exact innerSqrt_upper_cert i x m hx hm hmlo hmhi hseed hinterval.1 hinterval.2 + +/-- Certificate-backed 1-ULP bracket under octave membership. -/ +theorem innerSqrt_bracket_of_octave + (i : Fin 256) (x m : Nat) + (hmlo : m * m ≤ x) + (hmhi : x < (m + 1) * (m + 1)) + (hOct : 2 ^ i.val ≤ x ∧ x < 2 ^ (i.val + 1)) : + m ≤ innerSqrt x ∧ innerSqrt x ≤ m + 1 := by + have hx : 0 < x := Nat.lt_of_lt_of_le (Nat.two_pow_pos i.val) hOct.1 + exact ⟨innerSqrt_lower x m hx hmlo, innerSqrt_upper_of_octave i x m hmlo hmhi hOct⟩ + +/-- The floor correction is correct. + Given z > 0, (z-1)² ≤ x < (z+1)², the correction gives isqrt(x). -/ +theorem floor_correction (x z : Nat) (hz : 0 < z) + (hlo : (z - 1) * (z - 1) ≤ x) + (hhi : x < (z + 1) * (z + 1)) : + let r := if x / z < z then z - 1 else z + r * r ≤ x ∧ x < (r + 1) * (r + 1) := by + simp only + by_cases h_lt : x / z < z + · -- x/z < z means z² > x (since z * (x/z) ≤ x < z * z) + simp [h_lt] + have h_zsq : x < z * z := by + have h_euc := Nat.div_add_mod x z + have h_mod := Nat.mod_lt x hz + -- x < z * (x/z + 1) and x/z + 1 ≤ z, so x < z * z + have h1 : x < z * (x / z + 1) := by rw [Nat.mul_add, Nat.mul_one]; omega + exact Nat.lt_of_lt_of_le h1 (Nat.mul_le_mul_left z (by omega)) + constructor + · exact hlo + · have : z - 1 + 1 = z := by omega + rw [this]; exact h_zsq + · -- x/z ≥ z means z² ≤ x + simp [h_lt] + simp only [Nat.not_lt] at h_lt + have h_zsq : z * z ≤ x := by + calc z * z ≤ z * (x / z) := Nat.mul_le_mul_left z h_lt + _ ≤ x := Nat.mul_div_le x z + exact ⟨h_zsq, hhi⟩ + +-- ============================================================================ +-- Named wrappers for the advertised theorem entry points +-- ============================================================================ + +/-- `innerSqrt_correct`: established lower-bound component. + For any witness `m` with `m² ≤ x` and `x > 0`, `innerSqrt x` is at least `m`. -/ +theorem innerSqrt_correct (x m : Nat) (hx : 0 < x) (hm : m * m ≤ x) : + m ≤ innerSqrt x := + innerSqrt_lower x m hx hm + +/-- `floorSqrt_correct`: correction-step correctness under a 1-ULP bracket + for the inner approximation. -/ +theorem floorSqrt_correct (x : Nat) (hz : 0 < innerSqrt x) + (hlo : (innerSqrt x - 1) * (innerSqrt x - 1) ≤ x) + (hhi : x < (innerSqrt x + 1) * (innerSqrt x + 1)) : + let r := floorSqrt x + r * r ≤ x ∧ x < (r + 1) * (r + 1) := by + unfold floorSqrt + simpa [Nat.ne_of_gt hz] using floor_correction x (innerSqrt x) hz hlo hhi + +/-- End-to-end correction theorem from the finite certificate assumptions. -/ +theorem floorSqrt_correct_cert + (i : Fin 256) (x m : Nat) + (hx : 0 < x) + (hm : 0 < m) + (hmlo : m * m ≤ x) + (hmhi : x < (m + 1) * (m + 1)) + (hseed : sqrtSeed x = SqrtCert.seedOf i) + (hlo : SqrtCert.loOf i ≤ m) + (hhi : m ≤ SqrtCert.hiOf i) : + let r := floorSqrt x + r * r ≤ x ∧ x < (r + 1) * (r + 1) := by + have hlow : m ≤ innerSqrt x := innerSqrt_lower x m hx hmlo + have hupp : innerSqrt x ≤ m + 1 := innerSqrt_upper_cert i x m hx hm hmlo hmhi hseed hlo hhi + have hz : 0 < innerSqrt x := Nat.lt_of_lt_of_le hm hlow + have hlo' : (innerSqrt x - 1) * (innerSqrt x - 1) ≤ x := by + have hz1 : innerSqrt x - 1 ≤ m := by omega + have hsq : (innerSqrt x - 1) * (innerSqrt x - 1) ≤ m * m := Nat.mul_le_mul hz1 hz1 + exact Nat.le_trans hsq hmlo + have hhi' : x < (innerSqrt x + 1) * (innerSqrt x + 1) := by + have hm1 : m + 1 ≤ innerSqrt x + 1 := by omega + have hsq : (m + 1) * (m + 1) ≤ (innerSqrt x + 1) * (innerSqrt x + 1) := + Nat.mul_le_mul hm1 hm1 + exact Nat.lt_of_lt_of_le hmhi hsq + exact floorSqrt_correct x hz hlo' hhi' + +/-- End-to-end correctness under octave membership plus the witness + `m² ≤ x < (m+1)²` (so `m` is the integer square root witness). -/ +theorem floorSqrt_correct_of_octave + (i : Fin 256) (x m : Nat) + (hmlo : m * m ≤ x) + (hmhi : x < (m + 1) * (m + 1)) + (hOct : 2 ^ i.val ≤ x ∧ x < 2 ^ (i.val + 1)) : + let r := floorSqrt x + r * r ≤ x ∧ x < (r + 1) * (r + 1) := by + have hx : 0 < x := Nat.lt_of_lt_of_le (Nat.two_pow_pos i.val) hOct.1 + have hm : 0 < m := by + by_cases hm0 : m = 0 + · subst hm0 + have hx1 : 1 ≤ x := Nat.succ_le_of_lt hx + have hlt1 : x < 1 := by simpa using hmhi + exact False.elim ((Nat.not_lt_of_ge hx1) hlt1) + · exact Nat.pos_of_ne_zero hm0 + have hseed : sqrtSeed x = SqrtCert.seedOf i := sqrtSeed_eq_seedOf_of_octave i x hOct + have hinterval : SqrtCert.loOf i ≤ m ∧ m ≤ SqrtCert.hiOf i := + m_within_cert_interval i x m hmlo hmhi hOct + exact floorSqrt_correct_cert i x m hx hm hmlo hmhi hseed hinterval.1 hinterval.2 + +/-- Universal `_sqrt` bracket on uint256 domain: + choose `m = natSqrt x` and derive `m ≤ innerSqrt x ≤ m+1`. -/ +theorem innerSqrt_bracket_u256 + (x : Nat) + (hx : 0 < x) + (hx256 : x < 2 ^ 256) : + let m := natSqrt x + m ≤ innerSqrt x ∧ innerSqrt x ≤ m + 1 := by + let i : Fin 256 := ⟨Nat.log2 x, (Nat.log2_lt (Nat.ne_of_gt hx)).2 hx256⟩ + let m := natSqrt x + have hmlo : m * m ≤ x := by simpa [m] using natSqrt_sq_le x + have hmhi : x < (m + 1) * (m + 1) := by simpa [m] using natSqrt_lt_succ_sq x + have hOct : 2 ^ i.val ≤ x ∧ x < 2 ^ (i.val + 1) := by + have hlog : 2 ^ Nat.log2 x ≤ x ∧ x < 2 ^ (Nat.log2 x + 1) := + (Nat.log2_eq_iff (Nat.ne_of_gt hx)).1 rfl + simpa [i] + exact innerSqrt_bracket_of_octave i x m hmlo hmhi hOct + +/-- Universal `_sqrt` bracket on uint256 domain (including `x = 0`). -/ +theorem innerSqrt_bracket_u256_all + (x : Nat) + (hx256 : x < 2 ^ 256) : + let m := natSqrt x + m ≤ innerSqrt x ∧ innerSqrt x ≤ m + 1 := by + by_cases hx0 : x = 0 + · subst hx0 + simp [natSqrt, innerSqrt] + · have hx : 0 < x := Nat.pos_of_ne_zero hx0 + simpa using innerSqrt_bracket_u256 x hx hx256 + +/-- Universal `sqrt` correctness on uint256 domain (Nat model): + for every `x < 2^256`, `floorSqrt x` satisfies the integer-sqrt spec. -/ +theorem floorSqrt_correct_u256 + (x : Nat) + (hx256 : x < 2 ^ 256) : + let r := floorSqrt x + r * r ≤ x ∧ x < (r + 1) * (r + 1) := by + by_cases hx0 : x = 0 + · subst hx0 + simp [floorSqrt, innerSqrt] + · have hx : 0 < x := Nat.pos_of_ne_zero hx0 + let i : Fin 256 := ⟨Nat.log2 x, (Nat.log2_lt (Nat.ne_of_gt hx)).2 hx256⟩ + let m := natSqrt x + have hmlo : m * m ≤ x := by simpa [m] using natSqrt_sq_le x + have hmhi : x < (m + 1) * (m + 1) := by simpa [m] using natSqrt_lt_succ_sq x + have hOct : 2 ^ i.val ≤ x ∧ x < 2 ^ (i.val + 1) := by + have hlog : 2 ^ Nat.log2 x ≤ x ∧ x < 2 ^ (Nat.log2 x + 1) := + (Nat.log2_eq_iff (Nat.ne_of_gt hx)).1 rfl + simpa [i] + exact floorSqrt_correct_of_octave i x m hmlo hmhi hOct + +/-- Canonical witness package for the advertised uint256 statement. -/ +theorem sqrt_witness_correct_u256 + (x : Nat) + (hx256 : x < 2 ^ 256) : + ∃ m, m * m ≤ x ∧ x < (m + 1) * (m + 1) ∧ + m ≤ innerSqrt x ∧ innerSqrt x ≤ m + 1 := by + refine ⟨natSqrt x, natSqrt_sq_le x, natSqrt_lt_succ_sq x, ?_⟩ + simpa using innerSqrt_bracket_u256_all x hx256 + +-- ============================================================================ +-- Summary of proof status +-- ============================================================================ + +/- + PROOF STATUS — ALL COMPLETE (0 sorry): + + ✓ Lemma 1 (Floor Bound): babylon_step_floor_bound + ✓ Lemma 2 (Absorbing Set): babylon_from_ceil, babylon_from_floor + ✓ Step Monotonicity: bstep_mono_x, bstep_mono_z + ✓ Overestimate Contraction: bstep_lt_of_overestimate + ✓ Finite certificate layer: d1..d6 bounds from offline literals + ✓ Lower Bound Chain: innerSqrt_lower (6x babylon_step_floor_bound) + ✓ Finite-Certificate Upper Bound: innerSqrt_upper_cert + ✓ Floor Correction: floor_correction (case split on x/z < z) + ✓ Octave Wiring: innerSqrt_upper_of_octave, floorSqrt_correct_of_octave + ✓ Universal uint256 wrappers: innerSqrt_bracket_u256_all, floorSqrt_correct_u256 + ✓ Theorem wrappers: innerSqrt_correct, floorSqrt_correct, floorSqrt_correct_cert +-/ diff --git a/formal/sqrt/SqrtProof/SqrtProof/StepMono.lean b/formal/sqrt/SqrtProof/SqrtProof/StepMono.lean new file mode 100644 index 000000000..986443c98 --- /dev/null +++ b/formal/sqrt/SqrtProof/SqrtProof/StepMono.lean @@ -0,0 +1,80 @@ +/- + Step monotonicity for overestimates. + When z² > x, the Babylonian step is non-decreasing in z. +-/ +import Init +import SqrtProof.FloorBound + +-- ============================================================================ +-- Core: x/z ≤ x/(z+1) + 1 for overestimates +-- ============================================================================ + +theorem div_drop_le_one (x z : Nat) (hz : 0 < z) (hov : x < z * z) : + x / z ≤ x / (z + 1) + 1 := by + by_cases hq : x / z = 0 + · -- 0 ≤ anything + 1 + rw [hq]; exact Nat.zero_le _ + · have hq_pos : 0 < x / z := Nat.pos_of_ne_zero hq + have hq_lt : x / z < z := Nat.div_lt_of_lt_mul hov + have h_mul_le : z * (x / z) ≤ x := Nat.mul_div_le x z + have h_x_ge_z : z ≤ x := Nat.le_trans (Nat.le_mul_of_pos_right z hq_pos) h_mul_le + -- Show (x/z - 1) * (z+1) ≤ x, then by le_div_iff: x/z - 1 ≤ x/(z+1). + -- Expand: (x/z - 1) * (z+1) = (x/z - 1) * z + (x/z - 1) * 1 + -- Part 1: (x/z - 1) * z = z * (x/z) - z * 1 ≤ x - z + have h_part1 : (x / z - 1) * z ≤ x - z := by + rw [Nat.mul_comm (x / z - 1) z, Nat.mul_sub z (x / z) 1, Nat.mul_one] + exact Nat.sub_le_sub_right h_mul_le z + -- Part 2: (x/z - 1) ≤ z - 2 + -- Combined: (x/z-1)*z + (x/z-1) ≤ (x-z) + (z-2) ≤ x + have h_prod : (x / z - 1) * (z + 1) ≤ x := by + rw [Nat.mul_add, Nat.mul_one] + -- (x/z-1)*z + (x/z - 1) ≤ (x - z) + (z - 2) = x - 2 ≤ x + -- Need: (x/z - 1) ≤ z - 2 and x - z + (z - 2) ≤ x + -- x/z - 1 ≤ z - 2: from x/z ≤ z - 1 (i.e., x/z < z) + -- x - z + (z - 2) = x - 2 ≤ x: true for x ≥ 0 + -- But Nat subtraction: x - z + (z - 2) might not equal x - 2. + -- Instead: (x/z-1)*z ≤ x - z and x/z - 1 ≤ z - 2. + -- sum ≤ x - z + z - 2 = x - 2 ≤ x (for z ≥ 2) + -- For z = 1: x/z = x, hq_lt says x < 1, so x = 0 and hq says x/1 = 0. Contradiction. + have hz2 : z ≥ 2 := by omega -- from hq_pos (x/z ≥ 1) and hq_lt (x/z < z) + -- (x/z-1)*z + (x/z-1) ≤ (x-z) + (z-2) + have : x / z - 1 ≤ z - 2 := by omega + -- Need: x - z + (z - 2) ≤ x. Since z ≥ 2: z - 2 ≤ z, x - z + (z-2) ≤ x - 2 ≤ x. + omega + have := (Nat.le_div_iff_mul_le (by omega : 0 < z + 1)).mpr h_prod + omega + +theorem sum_nondec_step (x z : Nat) (hz : 0 < z) (hov : x < z * z) : + z + x / z ≤ (z + 1) + x / (z + 1) := by + have := div_drop_le_one x z hz hov; omega + +-- ============================================================================ +-- Step monotonicity +-- ============================================================================ + +theorem bstep_mono_x {x₁ x₂ z : Nat} (hx : x₁ ≤ x₂) (_hz : 0 < z) : + bstep x₁ z ≤ bstep x₂ z := by + unfold bstep + have : x₁ / z ≤ x₂ / z := Nat.div_le_div_right hx; omega + +theorem bstep_mono_z (x z₁ z₂ : Nat) (hz : 0 < z₁) + (hov : x < z₁ * z₁) (hle : z₁ ≤ z₂) : + bstep x z₁ ≤ bstep x z₂ := by + unfold bstep + suffices z₁ + x / z₁ ≤ z₂ + x / z₂ by + exact Nat.div_le_div_right this + induction z₂ with + | zero => omega + | succ n ih => + by_cases h : z₁ ≤ n + · have hn : 0 < n := by omega + have hov_n : x < n * n := + Nat.lt_of_lt_of_le hov (Nat.mul_le_mul h h) + exact Nat.le_trans (ih h) (sum_nondec_step x n hn hov_n) + · have h_eq : z₁ = n + 1 := by omega + subst h_eq; omega + +theorem bstep_lt_of_overestimate (x z : Nat) (_hz : 0 < z) (hov : x < z * z) : + bstep x z < z := by + unfold bstep + have : x / z < z := Nat.div_lt_of_lt_mul hov; omega diff --git a/formal/sqrt/SqrtProof/lake-manifest.json b/formal/sqrt/SqrtProof/lake-manifest.json new file mode 100644 index 000000000..9452fee48 --- /dev/null +++ b/formal/sqrt/SqrtProof/lake-manifest.json @@ -0,0 +1,5 @@ +{"version": "1.1.0", + "packagesDir": ".lake/packages", + "packages": [], + "name": "SqrtProof", + "lakeDir": ".lake"} diff --git a/formal/sqrt/SqrtProof/lakefile.toml b/formal/sqrt/SqrtProof/lakefile.toml new file mode 100644 index 000000000..14f5ded88 --- /dev/null +++ b/formal/sqrt/SqrtProof/lakefile.toml @@ -0,0 +1,10 @@ +name = "SqrtProof" +version = "0.1.0" +defaultTargets = ["SqrtProof"] + +[[lean_lib]] +name = "SqrtProof" + +[[lean_exe]] +name = "sqrt-model" +root = "Main" diff --git a/formal/sqrt/SqrtProof/lean-toolchain b/formal/sqrt/SqrtProof/lean-toolchain new file mode 100644 index 000000000..4c685fa08 --- /dev/null +++ b/formal/sqrt/SqrtProof/lean-toolchain @@ -0,0 +1 @@ +leanprover/lean4:v4.28.0 diff --git a/formal/sqrt/generate_sqrt512_model.py b/formal/sqrt/generate_sqrt512_model.py new file mode 100644 index 000000000..6198ed851 --- /dev/null +++ b/formal/sqrt/generate_sqrt512_model.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 +""" +Generate Lean model of 512Math sqrt functions from Yul IR. + +This script extracts `_sqrt_babylonianStep`, `_sqrt_baseCase`, +`_sqrt_karatsubaQuotient`, `_sqrt_correction`, `_sqrt`, `sqrt`, +`sqrtUp`, `wrap_sqrt512`, and `wrap_osqrtUp` from the Yul IR produced +by `forge inspect` on Sqrt512Wrapper and emits opcode-faithful uint256 +EVM Lean definitions (norm model suppressed via evm_only=True since the +proofs bridge the EVM model directly). + +By keeping the sub-functions in `function_order`, the pipeline emits +separate models for each. `model_sqrt512_evm` calls into sub-models +rather than inlining their bodies, producing smaller Lean terms. +The public wrappers (`sqrt`, `osqrtUp`) call `model_sqrt512_evm` +and inline all other helpers (256-bit sqrt, _mul, _gt, _add) as raw +opcodes. + +All compiler-generated helper functions (type conversions, wrapping +arithmetic, library calls) are inlined to raw opcodes automatically. +""" + +from __future__ import annotations + +import sys +from pathlib import Path + +# Allow importing the shared module from formal/ +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from yul_to_lean import ModelConfig, run + +CONFIG = ModelConfig( + function_order=( + "_sqrt_babylonianStep", "_sqrt_baseCase", + "_sqrt_karatsubaQuotient", "_sqrt_correction", + "_sqrt", + # 256-bit sqrt/sqrtUp from Sqrt.sol — kept as named sub-models so the + # public wrappers don't inline the full Babylonian chain, which would + # cause (kernel) deep recursion in the Lean proofs. + "sqrt", "sqrtUp", + "wrap_sqrt512", "wrap_osqrtUp", + ), + model_names={ + "_sqrt_babylonianStep": "model_bstep", + "_sqrt_baseCase": "model_innerSqrt", + "_sqrt_karatsubaQuotient": "model_karatsubaQuotient", + "_sqrt_correction": "model_sqrtCorrection", + "_sqrt": "model_sqrt512", + "sqrt": "model_sqrt256_floor", + "sqrtUp": "model_sqrt256_up", + "wrap_sqrt512": "model_sqrt512_wrapper", + "wrap_osqrtUp": "model_osqrtUp", + }, + header_comment="Auto-generated from Solidity 512Math._sqrt assembly and assignment flow.", + generator_label="formal/sqrt/generate_sqrt512_model.py", + extra_norm_ops={}, + extra_lean_defs="", + norm_rewrite=None, + inner_fn="_sqrt", + n_params={ + "_sqrt_babylonianStep": 2, + "_sqrt_baseCase": 1, + "_sqrt_karatsubaQuotient": 3, + "_sqrt_correction": 4, + "_sqrt": 2, + "sqrt": 1, + "sqrtUp": 1, + "wrap_sqrt512": 2, + "wrap_osqrtUp": 2, + }, + keep_solidity_locals=True, + # 256-bit sqrt/sqrtUp share names with 512-bit wrappers; use + # exclude_known to select the leaf (256-bit) versions that do NOT + # call the already-targeted _sqrt (512-bit). + exclude_known=frozenset({"sqrt", "sqrtUp"}), + # Suppress norm models for functions whose proofs bridge the EVM model + # directly (the norm model uses unbounded Nat which doesn't match EVM). + skip_norm=frozenset({"sqrt", "sqrtUp", "wrap_sqrt512", "wrap_osqrtUp"}), + default_source_label="src/utils/512Math.sol", + default_namespace="Sqrt512GeneratedModel", + default_output="formal/sqrt/Sqrt512Proof/Sqrt512Proof/GeneratedSqrt512Model.lean", + cli_description="Generate Lean model of 512Math._sqrt from Yul IR", +) + + +if __name__ == "__main__": + raise SystemExit(run(CONFIG)) diff --git a/formal/sqrt/generate_sqrt_cert.py b/formal/sqrt/generate_sqrt_cert.py new file mode 100644 index 000000000..ba1504d0f --- /dev/null +++ b/formal/sqrt/generate_sqrt_cert.py @@ -0,0 +1,410 @@ +#!/usr/bin/env python3 +""" +Generate finite-certificate tables for the sqrt formal proof. + +For each of 256 octaves (n = 0..255), where octave n contains x in [2^n, 2^(n+1) - 1]: + - loOf(n) = isqrt(2^n) -- lower bound on isqrt(x) + - hiOf(n) = isqrt(2^(n+1) - 1) -- upper bound on isqrt(x) + - seedOf(n) = 1 << ((n+1)/2) -- sqrt seed for octave n + - d1(n): first-step error bound from algebraic formula + - d2..d6(n): chained via nextD(lo, d) = d^2/(2*lo) + 1 + +The d1 bound uses the quadratic identity: + (z1 - m)^2 ≤ (s - m)^2 + 2*m*(m+1)/s * ((s-m)^2 / (2*m)) + ≤ maxAbs^2 + 2*hi +where maxAbs = max(|s-lo|, |hi-s|). +So d1 = (maxAbs^2 + 2*hi) / (2*seed). + +Also generates a Sqrt512Cert namespace with fixed-seed certificates for +octaves 254/255, used by the 512-bit sqrt proof. +""" + +import argparse +import sys + + +def isqrt(x): + """Integer square root (floor).""" + if x <= 0: + return 0 + if x < 4: + return 1 + n = x.bit_length() + z = 1 << ((n + 1) // 2) + while True: + z1 = (z + x // z) // 2 + if z1 >= z: + break + z = z1 + while z * z > x: + z -= 1 + while (z + 1) ** 2 <= x: + z += 1 + return z + + +def sqrt_step(x, z): + """One Babylonian step: floor((z + floor(x/z)) / 2)""" + if z == 0: + return 0 + return (z + x // z) // 2 + + +def sqrt_seed(n): + """Seed for octave n: 1 << ((n+1)/2).""" + return 1 << ((n + 1) // 2) + + +def next_d(lo, d): + """Error recurrence: d^2/(2*lo) + 1.""" + if lo == 0: + return d * d + 1 + return d * d // (2 * lo) + 1 + + +def compute_maxabs(lo, hi, s): + """max(|s - lo|, |hi - s|)""" + return max(abs(s - lo), abs(hi - s)) + + +def compute_d1(lo, hi, s): + """Analytic d1 bound: + d1 = floor((maxAbs^2 + 2*hi) / (2*s)) + """ + maxAbs = compute_maxabs(lo, hi, s) + numerator = maxAbs * maxAbs + 2 * hi + denominator = 2 * s + if denominator == 0: + return 0 + return numerator // denominator + + +def main(): + parser = argparse.ArgumentParser( + description="Generate finite-certificate tables for sqrt formal proof" + ) + parser.add_argument( + "--output", + default="SqrtProof/SqrtProof/FiniteCert.lean", + help="Output Lean file path (default: SqrtProof/SqrtProof/FiniteCert.lean)", + ) + args = parser.parse_args() + + lo_table = [] + hi_table = [] + + for n in range(256): + lo = isqrt(1 << n) + hi = isqrt((1 << (n + 1)) - 1) + lo_table.append(lo) + hi_table.append(hi) + + # Verify basic properties + for n in range(256): + lo = lo_table[n] + hi = hi_table[n] + assert lo * lo <= (1 << n), f"lo^2 > 2^n at n={n}" + assert (1 << (n + 1)) <= (hi + 1) ** 2, f"2^(n+1) > (hi+1)^2 at n={n}" + assert lo <= hi, f"lo > hi at n={n}" + + # Compute certificate for all 256 octaves + all_ok = True + d_data = {} # n -> (d1, ..., d6) + + for n in range(256): + lo = lo_table[n] + hi = hi_table[n] + seed = sqrt_seed(n) + + d1 = compute_d1(lo, hi, seed) + d2 = next_d(lo, d1) + d3 = next_d(lo, d2) + d4 = next_d(lo, d3) + d5 = next_d(lo, d4) + d6 = next_d(lo, d5) + d_data[n] = (d1, d2, d3, d4, d5, d6) + + if d6 > 1: + print(f"FAIL d6: n={n}, d1={d1}, d2={d2}, d3={d3}, " + f"d4={d4}, d5={d5}, d6={d6}, lo={lo}") + all_ok = False + + # Check side conditions: dk <= lo for k=1..5 + for k, dk in enumerate([d1, d2, d3, d4, d5], 1): + if dk > lo: + print(f"SIDE FAIL: n={n}, d{k}={dk} > lo={lo}") + all_ok = False + + if all_ok: + print(f"All octaves 0-255 pass: d6 <= 1, all side conditions OK.") + else: + print("SOME OCTAVES FAIL.") + + # Exhaustive verification for small octaves to confirm d1 bound + print(f"\nExhaustive verification of d1 for octaves 0-30...") + for n in range(min(31, 256)): + lo = lo_table[n] + hi = hi_table[n] + seed = sqrt_seed(n) + d1_cert = d_data[n][0] + + for m in range(lo, hi + 1): + x_lo_m = max(m * m, 1 << n) + x_hi_m = min((m + 1) ** 2 - 1, (1 << (n + 1)) - 1) + if x_lo_m > x_hi_m: + continue + z1 = sqrt_step(x_hi_m, seed) # max z1 by mono in x + actual_d1 = max(0, z1 - m) + if actual_d1 > d1_cert: + print(f" D1 FAIL: n={n}, m={m}, z1={z1}, actual_d1={actual_d1}, cert={d1_cert}") + all_ok = False + print(" d1 exhaustive check done.") + + # Spot-check d1 for large octaves + import random + random.seed(42) + print("\nSpot-checking d1 for large octaves...") + for n in range(100, 256, 10): + lo = lo_table[n] + hi = hi_table[n] + seed = sqrt_seed(n) + d1_cert = d_data[n][0] + + for m in [lo, hi, lo + (hi - lo) // 3, lo + 2 * (hi - lo) // 3]: + x_max = min((m + 1) ** 2 - 1, (1 << (n + 1)) - 1) + x_min = max(m ** 2, 1 << n) + if x_min > x_max: + continue + z1 = sqrt_step(x_max, seed) + actual_d1 = max(0, z1 - m) + if actual_d1 > d1_cert: + print(f" SPOT FAIL: n={n}, m={m}, z1={z1}, actual_d1={actual_d1}, cert={d1_cert}") + all_ok = False + print(" Spot check done.") + + # Summary + print(f"\n--- Summary (octaves 0-255) ---") + for k in range(6): + vals = [d_data[n][k] for n in range(256)] + mx = max(vals) + mi = vals.index(mx) + print(f" Max d{k+1}: {mx} at n={mi}") + + # Print d1/lo ratios for a few octaves + print(f"\n--- d1/lo ratios ---") + for n in [0, 2, 5, 10, 20, 50, 85, 100, 123, 170, 200, 255]: + lo = lo_table[n] + d1 = d_data[n][0] + if lo > 0: + print(f" n={n}: lo={lo}, d1={d1}, d1/lo={d1/lo:.6f}") + + # Verify Sqrt512Cert: fixed-seed certificate for octaves 254/255 + FIXED_SEED = lo_table[255] # = isqrt(2^255) + print(f"\n--- Sqrt512Cert verification ---") + print(f" FIXED_SEED = {FIXED_SEED}") + assert FIXED_SEED == hi_table[254], f"FIXED_SEED != hi(254)" + assert FIXED_SEED == lo_table[255], f"FIXED_SEED != lo(255)" + + for octave in [254, 255]: + lo = lo_table[octave] + hi = hi_table[octave] + ma = compute_maxabs(lo, hi, FIXED_SEED) + fd1 = compute_d1(lo, hi, FIXED_SEED) + fd2 = next_d(lo, fd1) + fd3 = next_d(lo, fd2) + fd4 = next_d(lo, fd3) + fd5 = next_d(lo, fd4) + fd6 = next_d(lo, fd5) + print(f" octave {octave}: lo={lo}, hi={hi}, maxAbs={ma}") + print(f" fd1={fd1}, fd2={fd2}, fd3={fd3}, fd4={fd4}, fd5={fd5}, fd6={fd6}") + assert fd6 <= 1, f"fd6 > 1 for octave {octave}!" + for k, dk in enumerate([fd1, fd2, fd3, fd4, fd5], 1): + assert dk <= lo, f"fd{k} > lo for octave {octave}!" + print(f" All checks pass.") + + # Generate Lean output + if all_ok: + generate_lean_file(lo_table, hi_table, d_data, FIXED_SEED, args.output) + + return 0 if all_ok else 1 + + +def generate_lean_file(lo_table, hi_table, d_data, fixed_seed, outpath): + """Generate the FiniteCert.lean file with SqrtCert and Sqrt512Cert namespaces.""" + print(f"\nGenerating {outpath}...") + + def fmt_array(name, values, comment=""): + lines = [] + if comment: + lines.append(f"/-- {comment} -/") + lines.append(f"def {name} : Array Nat := #[") + for i, v in enumerate(values): + comma = "," if i < len(values) - 1 else "" + lines.append(f" {v}{comma}") + lines.append("]") + return "\n".join(lines) + + # ========================================================================= + # SqrtCert namespace + # ========================================================================= + + content = f"""import Init + +/- + Finite certificate for sqrt upper bound, covering all 256 octaves. + + For each octave i (n = 0..255), the tables provide: + - loOf(i): lower bound on isqrt(x) for x in [2^i, 2^(i+1)-1] + - hiOf(i): upper bound on isqrt(x) + - seedOf(i): the sqrt seed for the octave = 1 <<< ((i+1)/2) + - maxAbs(i): max(|seed - lo|, |hi - seed|) + - d1(i): first-step error bound (analytic) + - nextD, d2..d6: chained error recurrence d^2/(2*lo) + 1 + + All 256 octaves verified: d6 <= 1 and dk <= lo for k=1..5. + + Auto-generated by formal/sqrt/generate_sqrt_cert.py — do not edit by hand. +-/ + +namespace SqrtCert + +set_option maxRecDepth 1000000 + +{fmt_array("loTable", lo_table, "Lower bounds on isqrt(x) for octaves 0..255.")} + +{fmt_array("hiTable", hi_table, "Upper bounds on isqrt(x) for octaves 0..255.")} + +def seedOf (i : Fin 256) : Nat := + 1 <<< ((i.val + 1) / 2) + +def loOf (i : Fin 256) : Nat := + loTable[i.val]! + +def hiOf (i : Fin 256) : Nat := + hiTable[i.val]! + +def maxAbs (i : Fin 256) : Nat := + max (seedOf i - loOf i) (hiOf i - seedOf i) + +def d1 (i : Fin 256) : Nat := + (maxAbs i * maxAbs i + 2 * hiOf i) / (2 * seedOf i) + +def nextD (lo d : Nat) : Nat := + d * d / (2 * lo) + 1 + +def d2 (i : Fin 256) : Nat := + nextD (loOf i) (d1 i) + +def d3 (i : Fin 256) : Nat := + nextD (loOf i) (d2 i) + +def d4 (i : Fin 256) : Nat := + nextD (loOf i) (d3 i) + +def d5 (i : Fin 256) : Nat := + nextD (loOf i) (d4 i) + +def d6 (i : Fin 256) : Nat := + nextD (loOf i) (d5 i) + +theorem lo_pos : ∀ i : Fin 256, 0 < loOf i := by + decide + +theorem d1_le_lo : ∀ i : Fin 256, d1 i ≤ loOf i := by + decide + +theorem d2_le_lo : ∀ i : Fin 256, d2 i ≤ loOf i := by + decide + +theorem d3_le_lo : ∀ i : Fin 256, d3 i ≤ loOf i := by + decide + +theorem d4_le_lo : ∀ i : Fin 256, d4 i ≤ loOf i := by + decide + +theorem d5_le_lo : ∀ i : Fin 256, d5 i ≤ loOf i := by + decide + +theorem d6_le_one : ∀ i : Fin 256, d6 i ≤ 1 := by + decide + +theorem lo_sq_le_pow2 : ∀ i : Fin 256, loOf i * loOf i ≤ 2 ^ i.val := by + decide + +theorem pow2_succ_le_hi_succ_sq : + ∀ i : Fin 256, 2 ^ (i.val + 1) ≤ (hiOf i + 1) * (hiOf i + 1) := by + decide + +end SqrtCert + +-- ============================================================================ +-- Sqrt512Cert: fixed-seed certificates for octaves 254/255 +-- Used by the 512-bit sqrt proof (Sqrt512Proof). +-- ============================================================================ + +namespace Sqrt512Cert + +open SqrtCert + +/-- The fixed Newton seed used by 512-bit sqrt: isqrt(2^255). + Equals hiOf(254) = loOf(255) in the finite certificate tables. -/ +def FIXED_SEED : Nat := {fixed_seed} + +def lo254 : Nat := loOf ⟨254, by omega⟩ +def hi254 : Nat := hiOf ⟨254, by omega⟩ +def maxAbs254 : Nat := max (FIXED_SEED - lo254) (hi254 - FIXED_SEED) +def fd1_254 : Nat := (maxAbs254 * maxAbs254 + 2 * hi254) / (2 * FIXED_SEED) +def fd2_254 : Nat := nextD lo254 fd1_254 +def fd3_254 : Nat := nextD lo254 fd2_254 +def fd4_254 : Nat := nextD lo254 fd3_254 +def fd5_254 : Nat := nextD lo254 fd4_254 +def fd6_254 : Nat := nextD lo254 fd5_254 + +set_option maxRecDepth 100000 in +theorem fd6_254_le_one : fd6_254 ≤ 1 := by decide +set_option maxRecDepth 100000 in +theorem fd1_254_le_lo : fd1_254 ≤ lo254 := by decide +set_option maxRecDepth 100000 in +theorem fd2_254_le_lo : fd2_254 ≤ lo254 := by decide +set_option maxRecDepth 100000 in +theorem fd3_254_le_lo : fd3_254 ≤ lo254 := by decide +set_option maxRecDepth 100000 in +theorem fd4_254_le_lo : fd4_254 ≤ lo254 := by decide +set_option maxRecDepth 100000 in +theorem fd5_254_le_lo : fd5_254 ≤ lo254 := by decide +theorem lo254_pos : 0 < lo254 := lo_pos ⟨254, by omega⟩ + +def lo255 : Nat := loOf ⟨255, by omega⟩ +def hi255 : Nat := hiOf ⟨255, by omega⟩ +def maxAbs255 : Nat := max (FIXED_SEED - lo255) (hi255 - FIXED_SEED) +def fd1_255 : Nat := (maxAbs255 * maxAbs255 + 2 * hi255) / (2 * FIXED_SEED) +def fd2_255 : Nat := nextD lo255 fd1_255 +def fd3_255 : Nat := nextD lo255 fd2_255 +def fd4_255 : Nat := nextD lo255 fd3_255 +def fd5_255 : Nat := nextD lo255 fd4_255 +def fd6_255 : Nat := nextD lo255 fd5_255 + +set_option maxRecDepth 100000 in +theorem fd6_255_le_one : fd6_255 ≤ 1 := by decide +set_option maxRecDepth 100000 in +theorem fd1_255_le_lo : fd1_255 ≤ lo255 := by decide +set_option maxRecDepth 100000 in +theorem fd2_255_le_lo : fd2_255 ≤ lo255 := by decide +set_option maxRecDepth 100000 in +theorem fd3_255_le_lo : fd3_255 ≤ lo255 := by decide +set_option maxRecDepth 100000 in +theorem fd4_255_le_lo : fd4_255 ≤ lo255 := by decide +set_option maxRecDepth 100000 in +theorem fd5_255_le_lo : fd5_255 ≤ lo255 := by decide +theorem lo255_pos : 0 < lo255 := lo_pos ⟨255, by omega⟩ + +end Sqrt512Cert +""" + + with open(outpath, "w") as f: + f.write(content) + print(f" Written to {outpath}") + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/formal/sqrt/generate_sqrt_model.py b/formal/sqrt/generate_sqrt_model.py new file mode 100644 index 000000000..57ce7d0a8 --- /dev/null +++ b/formal/sqrt/generate_sqrt_model.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +""" +Generate Lean models of Sqrt.sol from Yul IR. + +This script extracts `_sqrt`, `sqrt`, and `sqrtUp` from the Yul IR produced by +`forge inspect` on a wrapper contract and emits Lean definitions for: +- opcode-faithful uint256 EVM semantics, and +- normalized Nat semantics. +""" + +from __future__ import annotations + +import sys +from pathlib import Path + +# Allow importing the shared module from formal/ +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from yul_to_lean import ModelConfig, run + +CONFIG = ModelConfig( + function_order=("_sqrt", "sqrt", "sqrtUp"), + model_names={ + "_sqrt": "model_sqrt", + "sqrt": "model_sqrt_floor", + "sqrtUp": "model_sqrt_up", + }, + header_comment="Auto-generated from Solidity Sqrt assembly and assignment flow.", + generator_label="formal/sqrt/generate_sqrt_model.py", + extra_norm_ops={}, + extra_lean_defs="", + norm_rewrite=None, + inner_fn="_sqrt", + default_source_label="src/vendor/Sqrt.sol", + default_namespace="SqrtGeneratedModel", + default_output="formal/sqrt/SqrtProof/SqrtProof/GeneratedSqrtModel.lean", + cli_description="Generate Lean model of Sqrt.sol functions from Yul IR", +) + + +if __name__ == "__main__": + raise SystemExit(run(CONFIG)) diff --git a/formal/yul_to_lean.py b/formal/yul_to_lean.py new file mode 100644 index 000000000..98eebc508 --- /dev/null +++ b/formal/yul_to_lean.py @@ -0,0 +1,2101 @@ +""" +Shared infrastructure for generating Lean models from Yul IR. + +Provides: +- Yul tokenizer and recursive-descent parser +- AST types (IntLit, Var, Call, Assignment, FunctionModel) +- Yul → FunctionModel conversion (copy propagation + demangling) +- Lean expression emission +- Common Lean source scaffolding +""" + +from __future__ import annotations + +import argparse +import datetime as dt +import pathlib +import re +import sys +import warnings +from collections import Counter +from dataclasses import dataclass +from typing import Callable + + +class ParseError(RuntimeError): + pass + + +# --------------------------------------------------------------------------- +# AST nodes (shared by Yul parser and Lean emitter) +# --------------------------------------------------------------------------- + +@dataclass(frozen=True) +class IntLit: + value: int + + +@dataclass(frozen=True) +class Var: + name: str + + +@dataclass(frozen=True) +class Call: + name: str + args: tuple["Expr", ...] + + +Expr = IntLit | Var | Call + + +@dataclass(frozen=True) +class Assignment: + target: str + expr: Expr + + +@dataclass(frozen=True) +class ConditionalBlock: + """An ``if cond { ... }`` or ``if/else`` block assigning to declared vars. + + ``condition`` is the Yul condition expression. + ``assignments`` are the assignments inside the if-body. + ``modified_vars`` lists the Solidity-level variable names that the block + may modify (used for Lean tuple-destructuring emission). + ``else_vars`` are the variable names for pass-through values when + there is no else-body (the pre-if values). + ``else_assignments`` are assignments for the else-body when present + (from ``switch`` or if/else constructs). + """ + condition: Expr + assignments: tuple[Assignment, ...] + modified_vars: tuple[str, ...] + else_vars: tuple[str, ...] | None = None + else_assignments: tuple[Assignment, ...] | None = None + + +# A model statement is either a plain assignment or a conditional block. +ModelStatement = Assignment | ConditionalBlock + + +@dataclass(frozen=True) +class FunctionModel: + fn_name: str + assignments: tuple[ModelStatement, ...] + param_names: tuple[str, ...] = ("x",) + return_names: tuple[str, ...] = ("z",) + + +# --------------------------------------------------------------------------- +# Yul tokenizer +# --------------------------------------------------------------------------- + +YUL_TOKEN_RE = re.compile( + r""" + (?P///[^\n]*) + | (?P//[^\n]*) + | (?P\s+) + | (?P"(?:[^"\\]|\\.)*") + | (?P0x[0-9a-fA-F]+) + | (?P[0-9]+) + | (?P:=) + | (?P->) + | (?P[A-Za-z_.$][A-Za-z0-9_.$]*) + | (?P\{) + | (?P\}) + | (?P\() + | (?P\)) + | (?P,) + | (?P:) +""", + re.VERBOSE, +) + +_TOKEN_KIND_MAP = { + "linecomment": None, + "blockcomment": None, + "ws": None, + "string": "string", + "hex": "num", + "num": "num", + "assign": ":=", + "arrow": "->", + "ident": "ident", + "lbrace": "{", + "rbrace": "}", + "lparen": "(", + "rparen": ")", + "comma": ",", + "colon": ":", +} + + +def tokenize_yul(source: str) -> list[tuple[str, str]]: + """Tokenize Yul IR source into a list of (kind, text) pairs. + + Comments and whitespace are discarded. String literals are kept as + single tokens so that braces inside ``"contract Foo {..."`` never + confuse downstream code. + """ + tokens: list[tuple[str, str]] = [] + pos = 0 + length = len(source) + while pos < length: + m = YUL_TOKEN_RE.match(source, pos) + if not m: + snippet = source[pos : pos + 30] + raise ParseError(f"Yul tokenizer stuck at position {pos}: {snippet!r}") + pos = m.end() + raw_kind = m.lastgroup + assert raw_kind is not None + kind = _TOKEN_KIND_MAP[raw_kind] + if kind is None: + continue + text = m.group() + tokens.append((kind, text)) + return tokens + + +# --------------------------------------------------------------------------- +# Yul recursive-descent parser +# --------------------------------------------------------------------------- + +@dataclass(frozen=True) +class ParsedIfBlock: + """Raw parsed ``if cond { body }`` or ``switch`` from Yul, before demangling. + + When ``else_body`` is present, this represents an if/else or a + ``switch expr case 0 { else_body } default { body }`` construct. + """ + condition: Expr + body: tuple[tuple[str, Expr], ...] + has_leave: bool = False + else_body: tuple[tuple[str, Expr], ...] | None = None + + +# A raw parsed statement is either an assignment or an if-block. +RawStatement = tuple[str, Expr] | ParsedIfBlock + + +@dataclass +class YulFunction: + """Parsed representation of a single Yul ``function`` definition.""" + yul_name: str + params: list[str] + rets: list[str] + assignments: list[RawStatement] + expr_stmts: list[Expr] | None = None + + @property + def param(self) -> str: + """Backward-compat accessor for single-parameter functions.""" + if len(self.params) != 1: + raise ValueError( + f"YulFunction {self.yul_name!r} has {len(self.params)} params; " + f"use .params instead of .param" + ) + return self.params[0] + + @property + def ret(self) -> str: + """Backward-compat accessor for single-return functions.""" + if len(self.rets) != 1: + raise ValueError( + f"YulFunction {self.yul_name!r} has {len(self.rets)} return vars; " + f"use .rets instead of .ret" + ) + return self.rets[0] + + +class YulParser: + """Recursive-descent parser over a pre-tokenized Yul token stream. + + Only the subset of Yul needed for our extraction is handled: function + definitions, ``let``/bare assignments, blocks, and ``leave``. + + Control flow (``if``, ``switch``, ``for``) is **rejected** — its + presence would make the straight-line Lean model incomplete and + silently wrong. Bare expression-statements are tracked and warned + about since they may indicate side-effectful operations the model + does not capture. + """ + + def __init__(self, tokens: list[tuple[str, str]]) -> None: + self.tokens = tokens + self.i = 0 + self._expr_stmts: list[Expr] = [] + + def _at_end(self) -> bool: + return self.i >= len(self.tokens) + + def _peek(self) -> tuple[str, str] | None: + if self._at_end(): + return None + return self.tokens[self.i] + + def _peek_kind(self) -> str | None: + tok = self._peek() + return tok[0] if tok else None + + def _pop(self) -> tuple[str, str]: + tok = self._peek() + if tok is None: + raise ParseError("Unexpected end of Yul token stream") + self.i += 1 + return tok + + def _expect(self, kind: str) -> str: + k, text = self._pop() + if k != kind: + raise ParseError(f"Expected {kind!r}, got {k!r} ({text!r})") + return text + + def _expect_ident(self) -> str: + return self._expect("ident") + + def _skip_until_matching_brace(self) -> None: + self._expect("{") + depth = 1 + while depth > 0: + k, _ = self._pop() + if k == "{": + depth += 1 + elif k == "}": + depth -= 1 + + def _parse_expr(self) -> Expr: + kind, text = self._pop() + if kind == "num": + return IntLit(int(text, 0)) + if kind == "ident": + if self._peek_kind() == "(": + self._pop() + args: list[Expr] = [] + if self._peek_kind() != ")": + while True: + args.append(self._parse_expr()) + if self._peek_kind() == ",": + self._pop() + continue + break + self._expect(")") + return Call(text, tuple(args)) + return Var(text) + if kind == "string": + return Var(text) + raise ParseError(f"Expected expression, got {kind!r} ({text!r})") + + def _parse_let(self, results: list) -> None: + """Parse a ``let`` statement and append to *results*. + + Handles three forms: + - ``let x := expr`` — single-value assignment + - ``let a, b, c := call()`` — multi-value; each target gets a + synthetic ``__component_N(call)`` wrapper to distinguish them + - ``let x`` — bare declaration (zero-init, skipped) + """ + self._pop() # consume 'let' + target = self._expect_ident() + if self._peek_kind() == ",": + all_targets: list[str] = [target] + while self._peek_kind() == ",": + self._pop() + all_targets.append(self._expect_ident()) + self._expect(":=") + expr = self._parse_expr() + for idx, t in enumerate(all_targets): + results.append((t, Call(f"__component_{idx}", (expr,)))) + elif self._peek_kind() == ":=": + self._pop() + expr = self._parse_expr() + results.append((target, expr)) + else: + # Bare declaration: ``let x`` (zero-initialized, skip) + pass + + def _parse_body_assignments(self) -> list[RawStatement]: + results: list[RawStatement] = [] + + while not self._at_end() and self._peek_kind() != "}": + kind = self._peek_kind() + + if kind == "{": + self._pop() + results.extend(self._parse_body_assignments()) + self._expect("}") + continue + + if kind == "ident" and self.tokens[self.i][1] == "let": + self._parse_let(results) + continue + + if kind == "ident" and self.tokens[self.i][1] == "leave": + self._pop() + continue + + if kind == "ident" and self.tokens[self.i][1] == "function": + self._skip_function_def() + continue + + if kind == "ident" and self.tokens[self.i][1] == "if": + self._pop() # consume 'if' + condition = self._parse_expr() + self._expect("{") + body, has_leave = self._parse_if_body_assignments() + self._expect("}") + results.append(ParsedIfBlock( + condition=condition, + body=tuple(body), + has_leave=has_leave, + )) + continue + + if kind == "ident" and self.tokens[self.i][1] == "switch": + self._pop() # consume 'switch' + condition = self._parse_expr() + # We support exactly one form of switch: + # switch e case 0 { else_body } default { if_body } + # (branches may appear in either order). Anything else + # is rejected loudly. + case0_body: list[tuple[str, Expr]] | None = None + case0_leave = False + default_body: list[tuple[str, Expr]] | None = None + default_leave = False + n_branches = 0 + while (not self._at_end() + and self._peek_kind() == "ident" + and self.tokens[self.i][1] in ("case", "default")): + branch = self.tokens[self.i][1] + self._pop() # consume 'case' or 'default' + if branch == "case": + case_val = self._parse_expr() + cv = _try_const_eval(case_val) + if cv != 0: + raise ParseError( + f"switch case value {case_val!r} is not 0. " + f"Only 'switch e case 0 {{ ... }} default " + f"{{ ... }}' is supported." + ) + if case0_body is not None: + raise ParseError( + "Duplicate 'case 0' in switch statement." + ) + self._expect("{") + case0_body, case0_leave = self._parse_if_body_assignments() + self._expect("}") + else: # default + if default_body is not None: + raise ParseError( + "Duplicate 'default' in switch statement." + ) + self._expect("{") + default_body, default_leave = self._parse_if_body_assignments() + self._expect("}") + # default must be the last branch. + n_branches += 1 + break + n_branches += 1 + # Reject trailing case branches after default. + if (default_body is not None + and not self._at_end() + and self._peek_kind() == "ident" + and self.tokens[self.i][1] in ("case", "default")): + raise ParseError( + "'default' must be the last branch in a switch." + ) + if n_branches == 0: + raise ParseError("switch with no case/default branches") + if n_branches != 2 or case0_body is None or default_body is None: + raise ParseError( + f"switch must have exactly 'case 0' + 'default' " + f"(got {n_branches} branch(es), case0=" + f"{'present' if case0_body is not None else 'missing'}" + f", default=" + f"{'present' if default_body is not None else 'missing'}" + f")." + ) + # Map to ParsedIfBlock: condition != 0 → default (if-body), + # condition == 0 → case 0 (else-body). + if_body = tuple(default_body) if default_body else () + else_body = tuple(case0_body) if case0_body else None + results.append(ParsedIfBlock( + condition=condition, + body=if_body, + has_leave=default_leave or case0_leave, + else_body=else_body, + )) + continue + + if kind == "ident" and self.tokens[self.i][1] == "for": + raise ParseError( + f"Control flow statement 'for' found in function body. " + f"Only straight-line code and if/switch blocks are " + f"supported for Lean model generation." + ) + + if kind == "ident" and self.i + 1 < len(self.tokens) and self.tokens[self.i + 1][0] == ":=": + target = self._expect_ident() + self._expect(":=") + expr = self._parse_expr() + results.append((target, expr)) + continue + + if kind == "ident" or kind == "num": + expr = self._parse_expr() + self._expr_stmts.append(expr) + continue + + tok = self._pop() + warnings.warn( + f"Unrecognized token {tok!r} in function body was skipped. " + f"This may indicate a Yul IR construct the parser does not " + f"handle.", + stacklevel=2, + ) + + return results + + def _parse_if_body_assignments( + self, + ) -> tuple[list[tuple[str, Expr]], bool]: + """Parse the body of an ``if`` block. + + Only bare assignments (``target := expr``) are expected inside + if-bodies in the Yul IR patterns we handle. ``let`` declarations + are also accepted (they are locals scoped to the if-body that the + compiler may introduce). + + Returns ``(assignments, has_leave)`` where *has_leave* indicates + that a ``leave`` statement (early return) was encountered. + """ + results: list[tuple[str, Expr]] = [] + has_leave = False + while not self._at_end() and self._peek_kind() != "}": + kind = self._peek_kind() + + if kind == "{": + self._pop() + inner_results, inner_leave = self._parse_if_body_assignments() + results.extend(inner_results) + has_leave = has_leave or inner_leave + self._expect("}") + continue + + if kind == "ident" and self.tokens[self.i][1] == "let": + self._parse_let(results) + continue + + if kind == "ident" and self.i + 1 < len(self.tokens) and self.tokens[self.i + 1][0] == ":=": + target = self._expect_ident() + self._expect(":=") + expr = self._parse_expr() + results.append((target, expr)) + continue + + if kind == "ident" and self.tokens[self.i][1] == "leave": + self._pop() + has_leave = True + continue + + if kind == "ident" or kind == "num": + expr = self._parse_expr() + self._expr_stmts.append(expr) + continue + + tok = self._pop() + warnings.warn( + f"Unrecognized token {tok!r} in if-body was skipped.", + stacklevel=2, + ) + return results, has_leave + + def _skip_function_def(self) -> None: + self._pop() # consume 'function' + self._expect_ident() + self._expect("(") + while self._peek_kind() != ")": + self._pop() + self._expect(")") + if self._peek_kind() == "->": + self._pop() + self._expect_ident() + while self._peek_kind() == ",": + self._pop() + self._expect_ident() + self._skip_until_matching_brace() + + def parse_function(self) -> YulFunction: + fn_kw = self._expect_ident() + assert fn_kw == "function", f"Expected 'function', got {fn_kw!r}" + yul_name = self._expect_ident() + self._expect("(") + params: list[str] = [] + if self._peek_kind() != ")": + params.append(self._expect_ident()) + while self._peek_kind() == ",": + self._pop() + params.append(self._expect_ident()) + self._expect(")") + rets: list[str] = [] + if self._peek_kind() == "->": + self._pop() + rets.append(self._expect_ident()) + while self._peek_kind() == ",": + self._pop() + rets.append(self._expect_ident()) + self._expect("{") + self._expr_stmts = [] + assignments = self._parse_body_assignments() + self._expect("}") + if self._expr_stmts: + descriptions = [] + for e in self._expr_stmts[:3]: + if isinstance(e, Call): + descriptions.append(f"{e.name}(...)") + else: + descriptions.append(repr(e)) + summary = ", ".join(descriptions) + if len(self._expr_stmts) > 3: + summary += ", ..." + warnings.warn( + f"Function {yul_name!r} contains " + f"{len(self._expr_stmts)} expression-statement(s) " + f"not captured in the model: [{summary}]. " + f"If any have side effects (sstore, log, revert, ...) " + f"the model may be incomplete.", + stacklevel=2, + ) + return YulFunction( + yul_name=yul_name, + params=params, + rets=rets, + assignments=assignments, + expr_stmts=self._expr_stmts if self._expr_stmts else None, + ) + + def _count_params_at(self, idx: int) -> int: + """Count the number of parameters of the function at token index ``idx``. + + Scans the parenthesized parameter list without advancing the main + cursor. Returns the count of comma-separated identifiers. + """ + # idx points to 'function', idx+1 is the name, idx+2 should be '(' + j = idx + 2 + if j >= len(self.tokens) or self.tokens[j][0] != "(": + return 0 + j += 1 # skip '(' + if j < len(self.tokens) and self.tokens[j][0] == ")": + return 0 + count = 1 + while j < len(self.tokens) and self.tokens[j][0] != ")": + if self.tokens[j][0] == ",": + count += 1 + j += 1 + return count + + def find_function( + self, sol_fn_name: str, *, n_params: int | None = None, + known_yul_names: set[str] | None = None, + exclude_known: bool = False, + ) -> YulFunction: + """Find and parse ``function fun_{sol_fn_name}_(...)``. + + When *n_params* is set and multiple candidates match the name + pattern, only those with exactly *n_params* parameters are kept. + + When *known_yul_names* is set and still ambiguous, prefer + candidates whose body references at least one of the given Yul + function names. This disambiguates e.g. ``sqrt(uint512)`` (which + calls ``_sqrt``) from ``Sqrt.sqrt(uint256)`` (which does not). + + When *exclude_known* is True, the filter is inverted: prefer + candidates whose body does NOT reference any known Yul name. + This selects leaf functions (e.g. 256-bit ``Sqrt.sqrt``) over + higher-level wrappers that call into already-targeted functions. + + Raises on zero or ambiguous matches. + """ + target_prefix = f"fun_{sol_fn_name}_" + matches: list[int] = [] + + for idx in range(len(self.tokens) - 1): + if ( + self.tokens[idx] == ("ident", "function") + and self.tokens[idx + 1][0] == "ident" + and self.tokens[idx + 1][1].startswith(target_prefix) + and self.tokens[idx + 1][1][len(target_prefix):].isdigit() + ): + matches.append(idx) + + if not matches: + raise ParseError( + f"Yul function for '{sol_fn_name}' not found " + f"(expected pattern fun_{sol_fn_name}_)" + ) + + if n_params is not None and len(matches) > 1: + filtered = [m for m in matches if self._count_params_at(m) == n_params] + if filtered: + matches = filtered + + if known_yul_names and len(matches) > 1: + if exclude_known: + filtered = [m for m in matches + if not self._body_references_any(m, known_yul_names)] + else: + filtered = [m for m in matches + if self._body_references_any(m, known_yul_names)] + if filtered: + matches = filtered + + if len(matches) > 1: + names = [self.tokens[m + 1][1] for m in matches] + raise ParseError( + f"Multiple Yul functions match '{sol_fn_name}': {names}. " + f"Rename wrapper functions to avoid collisions " + f"(e.g. prefix with 'wrap_')." + ) + + self.i = matches[0] + return self.parse_function() + + def _body_references_any(self, fn_start: int, yul_names: set[str]) -> bool: + """Check if the function at *fn_start* references any identifier in *yul_names*.""" + depth = 0 + started = False + for j in range(fn_start, len(self.tokens)): + k, text = self.tokens[j] + if k == "{": + depth += 1 + started = True + elif k == "}": + depth -= 1 + if started and depth == 0: + return False + elif k == "ident" and text in yul_names: + return True + return False + + def collect_all_functions(self) -> dict[str, YulFunction]: + """Parse all function definitions in the token stream. + + Functions whose bodies contain unsupported constructs (``switch``, + ``for``, etc.) are silently skipped — they cannot be inlined but + that is fine for model generation. + + Warnings about expression-statements (``revert``, ``mstore``, etc.) + are suppressed because these auxiliary functions are parsed only for + inlining, not for direct modelling. + """ + functions: dict[str, YulFunction] = {} + while not self._at_end(): + if ( + self._peek_kind() == "ident" + and self.tokens[self.i][1] == "function" + ): + saved_i = self.i + saved_stmts = self._expr_stmts + try: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + fn = self.parse_function() + functions[fn.yul_name] = fn + except ParseError: + # Unsupported body — skip this function. + self.i = saved_i + self._pop() # consume 'function' + self._expect_ident() # consume name + self._expect("(") + while self._peek_kind() != ")": + self._pop() + self._expect(")") + if self._peek_kind() == "->": + self._pop() + self._expect_ident() + while self._peek_kind() == ",": + self._pop() + self._expect_ident() + self._skip_until_matching_brace() + finally: + self._expr_stmts = saved_stmts + else: + self._pop() + return functions + + +# --------------------------------------------------------------------------- +# Yul → FunctionModel conversion +# --------------------------------------------------------------------------- + + +def demangle_var( + name: str, + param_vars: list[str], + return_vars: list[str] | str, + *, + keep_solidity_locals: bool = False, +) -> str | None: + """Map a Yul variable name back to its Solidity-level name. + + Returns the cleaned name, or None if the variable is a compiler temporary + that should be copy-propagated away. + + ``param_vars`` is a list of Yul parameter variable names (supports + multi-parameter functions). + + ``return_vars`` is a list of Yul return variable names (or a single + string for backward compatibility with single-return functions). + + When *keep_solidity_locals* is True, variables matching the + ``var__`` pattern (compiler representation of + Solidity-declared locals) are kept in the model even if they are + not the function parameter or return variable. + """ + if isinstance(return_vars, str): + return_vars = [return_vars] + if name in param_vars or name in return_vars: + m = re.fullmatch(r"var_(\w+?)_\d+", name) + return m.group(1) if m else name + if name.startswith("usr$"): + return name[4:] + if keep_solidity_locals: + m = re.fullmatch(r"var_(\w+?)_\d+", name) + if m: + return m.group(1) + return None + + +def rename_expr(expr: Expr, var_map: dict[str, str], fn_map: dict[str, str]) -> Expr: + if isinstance(expr, IntLit): + return expr + if isinstance(expr, Var): + return Var(var_map.get(expr.name, expr.name)) + if isinstance(expr, Call): + new_name = fn_map.get(expr.name, expr.name) + new_args = tuple(rename_expr(a, var_map, fn_map) for a in expr.args) + return Call(new_name, new_args) + raise TypeError(f"Unsupported Expr node: {type(expr)}") + + +def substitute_expr(expr: Expr, subst: dict[str, Expr]) -> Expr: + if isinstance(expr, IntLit): + return expr + if isinstance(expr, Var): + return subst.get(expr.name, expr) + if isinstance(expr, Call): + return Call(expr.name, tuple(substitute_expr(a, subst) for a in expr.args)) + raise TypeError(f"Unsupported Expr node: {type(expr)}") + + +# --------------------------------------------------------------------------- +# Function inlining +# --------------------------------------------------------------------------- + +_inline_counter = 0 + + +def _gensym(prefix: str) -> str: + """Generate a unique variable name for inlined function locals.""" + global _inline_counter + _inline_counter += 1 + return f"_inline_{prefix}_{_inline_counter}" + + +def _try_const_eval(expr: Expr) -> int | None: + """Try to evaluate an expression to a constant integer. + + Returns ``None`` if the expression contains variables or unsupported + operations. Used for resolving constant memory addresses in + mstore/mload folding. + """ + if isinstance(expr, IntLit): + return expr.value + if isinstance(expr, Call): + if expr.name == "add" and len(expr.args) == 2: + a = _try_const_eval(expr.args[0]) + b = _try_const_eval(expr.args[1]) + if a is not None and b is not None: + return (a + b) % (2 ** 256) + if expr.name == "sub" and len(expr.args) == 2: + a = _try_const_eval(expr.args[0]) + b = _try_const_eval(expr.args[1]) + if a is not None and b is not None: + return (a + 2 ** 256 - b) % (2 ** 256) + # Handle __ite(cond, if_val, else_val): if both branches + # evaluate to the same constant, the result is that constant + # regardless of the condition. + if expr.name == "__ite" and len(expr.args) == 3: + if_val = _try_const_eval(expr.args[1]) + else_val = _try_const_eval(expr.args[2]) + if if_val is not None and else_val is not None and if_val == else_val: + return if_val + return None + + +def _inline_single_call( + fn: YulFunction, + args: tuple[Expr, ...], + fn_table: dict[str, YulFunction], + depth: int, + max_depth: int, + mstore_sink: list[tuple[str, Expr]] | None = None, +) -> Expr | tuple[Expr, ...]: + """Inline one function call, returning its return-value expression(s). + + Builds a substitution from parameters → argument expressions, then + processes the function body sequentially (same semantics as copy-prop). + Each local variable gets a unique gensym name to avoid clashes with + the caller's scope. + + When *mstore_sink* is not None, ``mstore(addr, val)`` expression- + statements from inlined functions are collected as synthetic + assignments ``(gensym_name, Call("__mstore", [addr, val]))``. The + caller is responsible for injecting these into the outer function's + assignment list so that ``yul_function_to_model`` can resolve + ``mload`` calls lazily during copy propagation. + """ + if fn.expr_stmts: + # Filter out mstore calls when we have a sink to capture them. + unhandled = [ + e for e in fn.expr_stmts + if not (mstore_sink is not None + and isinstance(e, Call) + and e.name == "mstore" + and len(e.args) == 2) + ] + if unhandled: + descriptions = [] + for e in unhandled[:3]: + if isinstance(e, Call): + descriptions.append(f"{e.name}(...)") + else: + descriptions.append(repr(e)) + summary = ", ".join(descriptions) + if len(unhandled) > 3: + summary += ", ..." + warnings.warn( + f"Inlining function {fn.yul_name!r} which contains " + f"{len(unhandled)} unhandled expression-statement(s): " + f"[{summary}]. If any have side effects (sstore, log, " + f"revert, ...) the inlined model may be incomplete.", + stacklevel=3, + ) + + subst: dict[str, Expr] = {} + for param, arg_expr in zip(fn.params, args): + subst[param] = arg_expr + # Also seed return variables with zero (they're typically zero-initialized) + for r in fn.rets: + if r not in subst: + subst[r] = IntLit(0) + + leave_cond: Expr | None = None # set when an if-block with leave is encountered + leave_subst: dict[str, Expr] | None = None + + for stmt in fn.assignments: + if isinstance(stmt, ParsedIfBlock): + # Evaluate condition + cond = substitute_expr(stmt.condition, subst) + cond = inline_calls(cond, fn_table, depth + 1, max_depth, + mstore_sink=mstore_sink) + # Process if-body assignments into a separate subst branch. + if_subst = dict(subst) + # Track mstore count to detect conditional memory writes. + pre_if_sink_len = len(mstore_sink) if mstore_sink is not None else 0 + for target, raw_expr in stmt.body: + expr = substitute_expr(raw_expr, if_subst) + expr = inline_calls(expr, fn_table, depth + 1, max_depth, + mstore_sink=mstore_sink) + if_subst[target] = expr + + # Reject conditional memory writes — they can't be modeled + # faithfully without tracking memory state per branch. + if mstore_sink is not None and len(mstore_sink) > pre_if_sink_len: + raise ParseError( + f"Conditional memory write detected in {fn.yul_name!r}: " + f"{len(mstore_sink) - pre_if_sink_len} mstore(s) emitted " + f"inside an if-block body. Restructure the wrapper so " + f"memory writes occur outside conditionals." + ) + + # Also process else_body if present (from switch). + if stmt.else_body is not None: + else_subst = dict(subst) + pre_else_sink_len = len(mstore_sink) if mstore_sink is not None else 0 + for target, raw_expr in stmt.else_body: + expr = substitute_expr(raw_expr, else_subst) + expr = inline_calls(expr, fn_table, depth + 1, max_depth, + mstore_sink=mstore_sink) + else_subst[target] = expr + if mstore_sink is not None and len(mstore_sink) > pre_else_sink_len: + raise ParseError( + f"Conditional memory write detected in " + f"{fn.yul_name!r}: mstore(s) emitted inside " + f"an else-body. Restructure the wrapper so " + f"memory writes occur outside conditionals." + ) + + if stmt.has_leave: + # The if-block contains ``leave`` (early return). Save + # the if-branch return values; remaining assignments + # after this if-block form the else branch. + leave_cond = cond + leave_subst = if_subst + # Don't update subst — remaining assignments use the + # pre-if state (the "else" path where the condition is false). + elif stmt.else_body is not None: + # If/else or switch: merge both branches with __ite. + all_targets: list[str] = [] + seen: set[str] = set() + for target, _ in (*stmt.body, *stmt.else_body): + if target not in seen: + seen.add(target) + all_targets.append(target) + for target in all_targets: + pre_val = subst.get(target, IntLit(0)) + if_val = if_subst.get(target, pre_val) + else_val = else_subst.get(target, pre_val) + if if_val is not else_val: + subst[target] = Call("__ite", (cond, if_val, else_val)) + elif if_val is not pre_val: + subst[target] = if_val + else: + # Normal if-block (no leave, no else): take the if-branch value. + for target, _raw_expr in stmt.body: + if_val = if_subst[target] + orig_val = subst.get(target, IntLit(0)) + if if_val is not orig_val: + subst[target] = if_val + else: + target, raw_expr = stmt + expr = substitute_expr(raw_expr, subst) + expr = inline_calls(expr, fn_table, depth + 1, max_depth, + mstore_sink=mstore_sink) + # Gensym: rename non-param, non-return locals to avoid clashes + if target not in fn.params and target not in fn.rets: + new_name = _gensym(target) + subst[target] = Var(new_name) + # Re-substitute the expression under the new name + # (it was already substituted, so just store it) + subst[new_name] = expr + else: + subst[target] = expr + + # Resolve any gensym'd variables remaining in return expressions. + # Iterate because gensym'd vars may reference other gensym'd vars. + def _resolve(e: Expr, s: dict[str, Expr]) -> Expr: + for _ in range(20): + prev = e + e = substitute_expr(e, s) + if e is prev: + break + return e + + # Emit mstore effects AFTER the full subst chain is built. + # 1. Collect effects from this function's own expr_stmts. + # 2. Resolve all sink entries through subst to eliminate gensyms. + if mstore_sink is not None: + # Step 1: emit this function's own mstore effects. + for e in (fn.expr_stmts or []): + if isinstance(e, Call) and e.name == "mstore" and len(e.args) == 2: + addr_expr = _resolve(substitute_expr(e.args[0], subst), subst) + val_expr = _resolve(substitute_expr(e.args[1], subst), subst) + syn_name = _gensym("__mstore") + mstore_sink.append( + (syn_name, Call("__mstore", (addr_expr, val_expr))) + ) + + # Step 2: resolve all sink entries through this level's subst. + for i in range(len(mstore_sink)): + name, val = mstore_sink[i] + if isinstance(val, Call) and val.name == "__mstore": + new_args = tuple(_resolve(a, subst) for a in val.args) + if any(na is not oa for na, oa in zip(new_args, val.args)): + mstore_sink[i] = (name, Call("__mstore", new_args)) + + def _get_ret(r: str) -> Expr: + else_val = _resolve(subst.get(r, IntLit(0)), subst) + if leave_cond is not None and leave_subst is not None: + if_val = _resolve(leave_subst.get(r, IntLit(0)), leave_subst) + resolved_cond = _resolve(leave_cond, leave_subst) + return Call("__ite", (resolved_cond, if_val, else_val)) + return else_val + + if len(fn.rets) == 1: + return _get_ret(fn.rets[0]) + return tuple(_get_ret(r) for r in fn.rets) + + +def inline_calls( + expr: Expr, + fn_table: dict[str, YulFunction], + depth: int = 0, + max_depth: int = 20, + mstore_sink: list[tuple[str, Expr]] | None = None, +) -> Expr: + """Recursively inline function calls in an expression. + + Walks the expression tree. When a ``Call`` targets a function in + *fn_table*, its body is inlined via sequential substitution. + ``__component_N`` wrappers (from multi-value ``let``) are resolved + to the Nth return value of the inlined function. + + When *mstore_sink* is not None, ``mstore`` side effects from inlined + functions are collected (see ``_inline_single_call``). + """ + if depth > max_depth: + return expr + if isinstance(expr, (IntLit, Var)): + return expr + if isinstance(expr, Call): + # Handle __component_N(Call(fn, ...)) for multi-return. + # Must check BEFORE recursively inlining arguments, because + # we need to inline the inner call as multi-return to extract + # the Nth component. + m = re.fullmatch(r"__component_(\d+)", expr.name) + if m and len(expr.args) == 1 and isinstance(expr.args[0], Call): + idx = int(m.group(1)) + inner = expr.args[0] + # Recursively inline the inner call's arguments first + inner_args = tuple(inline_calls(a, fn_table, depth, + mstore_sink=mstore_sink) + for a in inner.args) + if inner.name in fn_table: + result = _inline_single_call( + fn_table[inner.name], inner_args, fn_table, depth + 1, + max_depth, mstore_sink=mstore_sink, + ) + if isinstance(result, tuple): + return result[idx] if idx < len(result) else expr + return result # single-return; component_0 = the value + # Inner call not in table — rebuild with inlined args + return Call(expr.name, (Call(inner.name, inner_args),)) + + # Recurse into arguments + args = tuple(inline_calls(a, fn_table, depth, + mstore_sink=mstore_sink) for a in expr.args) + + # Direct call to a collected function + if expr.name in fn_table: + fn = fn_table[expr.name] + result = _inline_single_call(fn, args, fn_table, depth + 1, + max_depth, mstore_sink=mstore_sink) + if isinstance(result, tuple): + return result[0] # single-call context; take first return + return result + + return Call(expr.name, args) + raise TypeError(f"Unsupported Expr node: {type(expr)}") + + +def _inline_yul_function( + yf: YulFunction, + fn_table: dict[str, YulFunction], +) -> YulFunction: + """Apply ``inline_calls`` to every expression in a YulFunction. + + When inlined functions contain ``mstore`` expression-statements, they + are collected and injected as synthetic ``__mstore`` assignments into + the outer function's assignment list. This enables lazy ``mload`` + resolution during ``yul_function_to_model``'s copy propagation. + """ + # Shared sink for mstore effects from all inlined functions. + # Effects are injected into the assignment list at the point they + # are collected (not prepended) so that variables they reference + # are already defined during copy propagation. + mstore_sink: list[tuple[str, Expr]] = [] + + new_assignments: list[RawStatement] = [] + for stmt in yf.assignments: + if isinstance(stmt, ParsedIfBlock): + pre_len = len(mstore_sink) + new_cond = inline_calls(stmt.condition, fn_table, + mstore_sink=mstore_sink) + new_body: list[tuple[str, Expr]] = [] + for target, raw_expr in stmt.body: + new_body.append((target, inline_calls(raw_expr, fn_table, + mstore_sink=mstore_sink))) + # Inject any mstore effects collected during this statement. + new_assignments.extend(mstore_sink[pre_len:]) + new_assignments.append(ParsedIfBlock( + condition=new_cond, + body=tuple(new_body), + has_leave=stmt.has_leave, + )) + else: + target, raw_expr = stmt + pre_len = len(mstore_sink) + inlined = inline_calls(raw_expr, fn_table, + mstore_sink=mstore_sink) + # Inject any mstore effects collected during this inlining. + new_assignments.extend(mstore_sink[pre_len:]) + new_assignments.append((target, inlined)) + + return YulFunction( + yul_name=yf.yul_name, + params=yf.params, + rets=yf.rets, + assignments=new_assignments, + ) + + +def yul_function_to_model( + yf: YulFunction, + sol_fn_name: str, + fn_map: dict[str, str], + keep_solidity_locals: bool = False, +) -> FunctionModel: + """Convert a parsed YulFunction into a FunctionModel. + + Performs copy propagation to eliminate compiler temporaries and renames + variables/calls back to Solidity-level names. + + Validates: + - Multi-assigned compiler temporaries are flagged (copy propagation is + still correct for sequential code, but the situation is unusual). + - The return variable is recognized and assigned in the model. + """ + # ------------------------------------------------------------------ + # Pre-pass: count how many times each variable is assigned. + # A compiler temporary assigned more than once is unusual and could + # indicate a naming-convention change that made a real variable look + # like a temporary. + # ------------------------------------------------------------------ + assign_counts: Counter[str] = Counter() + for stmt in yf.assignments: + if isinstance(stmt, ParsedIfBlock): + for target, _ in stmt.body: + assign_counts[target] += 1 + else: + target, _ = stmt + assign_counts[target] += 1 + + var_map: dict[str, str] = {} + subst: dict[str, Expr] = {} + + for name in [*yf.params, *yf.rets]: + clean = demangle_var(name, yf.params, yf.rets, keep_solidity_locals=keep_solidity_locals) + if clean: + var_map[name] = clean + + # Save param names before SSA processing may rename them. + param_names = tuple(var_map[p] for p in yf.params) + + # ------------------------------------------------------------------ + # SSA state: track assignment count per clean name so that + # reassigned variables get distinct Lean names (_1, _2, ...). + # Parameters start at count 1 (the function-parameter binding). + # ------------------------------------------------------------------ + ssa_count: dict[str, int] = {} + for name in yf.params: + clean = var_map.get(name) + if clean: + ssa_count[clean] = 1 + + assignments: list[ModelStatement] = [] + warned_multi: set[str] = set() + + def _freeze_refs(expr: Expr) -> Expr: + """Replace Var refs to Solidity-level vars with current Lean names, + and rename function calls through ``fn_map``. + + Called when a compiler temporary is copy-propagated. By + resolving Solidity-level ``Var`` nodes to their *current* Lean + name at copy-propagation time we "freeze" the reference, + preventing a later SSA rename of the same variable from + changing what the expression points to. + + Also renames function calls (e.g. ``fun__sqrt_4544`` → ``model_sqrt512``) + so they are correct if the expression is later substituted into a + real variable's assignment without going through ``rename_expr``. + """ + if isinstance(expr, IntLit): + return expr + if isinstance(expr, Var): + lean_name = var_map.get(expr.name) + if lean_name is not None: + return Var(lean_name) + return expr + if isinstance(expr, Call): + new_args = tuple(_freeze_refs(a) for a in expr.args) + new_name = fn_map.get(expr.name, expr.name) + return Call(new_name, new_args) + return expr + + def _process_assignment( + target: str, raw_expr: Expr, *, inside_conditional: bool = False, + ) -> Assignment | None: + """Process a single raw assignment through copy-prop and demangling. + + Returns an Assignment if the target is a real variable, or None if + it was copy-propagated into ``subst``. + """ + expr = substitute_expr(raw_expr, subst) + + clean = demangle_var(target, yf.params, yf.rets, keep_solidity_locals=keep_solidity_locals) + if clean is None: + if assign_counts[target] > 1 and target not in warned_multi: + warned_multi.add(target) + warnings.warn( + f"Variable {target!r} in {sol_fn_name!r} is classified " + f"as a compiler temporary (copy-propagated) but is " + f"assigned {assign_counts[target]} times. Sequential " + f"propagation preserves semantics for straight-line " + f"code, but this is unusual — verify the Yul IR to " + f"confirm this is not a misclassified user variable.", + stacklevel=2, + ) + if isinstance(expr, Call) and expr.name.startswith("zero_value_for_split_"): + subst[target] = IntLit(0) + else: + subst[target] = _freeze_refs(expr) + return None + + # Rename the RHS expression BEFORE updating var_map so that + # self-references (e.g. ``x := f(x)``) resolve to the + # *previous* binding, not the one being created. + skip_zero = isinstance(expr, IntLit) and expr.value == 0 + if not skip_zero: + expr = rename_expr(expr, var_map, fn_map) + + # SSA: compute the Lean target name. Inside conditional + # blocks, Lean's scoped ``let`` handles shadowing, so we + # use the base clean name directly. + if not inside_conditional: + ssa_count[clean] = ssa_count.get(clean, 0) + 1 + if ssa_count[clean] == 1: + ssa_name = clean + else: + ssa_name = f"{clean}_{ssa_count[clean] - 1}" + else: + ssa_name = clean + + # Update var_map AFTER rename_expr. + var_map[target] = ssa_name + + if skip_zero: + return None + + return Assignment(target=ssa_name, expr=expr) + + for stmt in yf.assignments: + if isinstance(stmt, ParsedIfBlock): + # Process the if-block: apply copy-prop/demangling to + # condition and body, then emit a ConditionalBlock. + cond = substitute_expr(stmt.condition, subst) + cond = rename_expr(cond, var_map, fn_map) + + # Save pre-if Lean names so the else-tuple can reference + # the values that were live *before* the if-body ran. + pre_if_names: dict[str, str] = {} + + body_assignments: list[Assignment] = [] + for target, raw_expr in stmt.body: + clean = demangle_var( + target, yf.params, yf.rets, + keep_solidity_locals=keep_solidity_locals, + ) + if clean is not None and clean not in pre_if_names: + pre_if_names[clean] = var_map.get(target, clean) + a = _process_assignment( + target, raw_expr, inside_conditional=True, + ) + if a is not None: + body_assignments.append(a) + if body_assignments: + # Deduplicate while preserving order. + seen_vars: set[str] = set() + modified_list: list[str] = [] + for a in body_assignments: + if a.target not in seen_vars: + seen_vars.add(a.target) + modified_list.append(a.target) + modified = tuple(modified_list) + + # Build else_vars from pre-if state (may differ from + # modified_vars when SSA is active). + else_vars_t = tuple( + pre_if_names.get(v, v) for v in modified_list + ) + else_vars = ( + else_vars_t if else_vars_t != modified else None + ) + + # Process else_body if present (from switch). + else_assgn: tuple[Assignment, ...] | None = None + if stmt.else_body is not None: + else_assignments_list: list[Assignment] = [] + for target, raw_expr in stmt.else_body: + a = _process_assignment( + target, raw_expr, inside_conditional=True, + ) + if a is not None: + else_assignments_list.append(a) + if else_assignments_list: + else_assgn = tuple(else_assignments_list) + # Ensure modified_vars covers vars from both + # branches. + for a in else_assignments_list: + if a.target not in seen_vars: + seen_vars.add(a.target) + modified_list.append(a.target) + modified = tuple(modified_list) + # When else_assignments are present, else_vars + # are not used (the else branch has its own + # computed values). + else_vars = None + + assignments.append(ConditionalBlock( + condition=cond, + assignments=tuple(body_assignments), + modified_vars=modified, + else_vars=else_vars, + else_assignments=else_assgn, + )) + + # After the if-block the Lean tuple-destructuring + # creates fresh bindings with the base clean names. + # Reset var_map and ssa_count accordingly so that + # subsequent references and assignments are correct. + modified_set = set(modified_list) + all_body_targets = list(stmt.body) + if stmt.else_body is not None: + all_body_targets.extend(stmt.else_body) + for target_name, _ in all_body_targets: + c = demangle_var( + target_name, yf.params, yf.rets, + keep_solidity_locals=keep_solidity_locals, + ) + if c is not None and c in modified_set: + var_map[target_name] = c + ssa_count[c] = 1 + continue + + target, raw_expr = stmt + a = _process_assignment(target, raw_expr) + if a is not None: + assignments.append(a) + + if not assignments: + raise ParseError(f"No assignments parsed for function {sol_fn_name!r}") + + # ------------------------------------------------------------------ + # Post-build validation: ensure the return variable(s) were recognized. + # If demangle_var failed to match a return variable's naming + # pattern, the model would silently lose the output. + # ------------------------------------------------------------------ + return_names_list: list[str] = [] + for ret_var in yf.rets: + return_clean = var_map.get(ret_var) + if return_clean is None: + raise ParseError( + f"Return variable {ret_var!r} of {sol_fn_name!r} was not " + f"recognized as a real variable by demangle_var. The compiler " + f"naming convention may have changed. Current patterns: " + f"var__ for param/return, usr$ for locals." + ) + # Use the final (possibly SSA-renamed) var_map entry. + return_names_list.append(var_map[ret_var]) + + model = FunctionModel( + fn_name=sol_fn_name, + assignments=tuple(assignments), + param_names=param_names, + return_names=tuple(return_names_list), + ) + + # ------------------------------------------------------------------ + # Lazy memory folding: resolve mload(addr) against __mstore(addr, val) + # synthetic assignments that were injected during inlining. + # ------------------------------------------------------------------ + # Build a lookup of Lean variable names → constant IntLit values + # from the model's assignments. This lets us resolve addresses + # like Var('x_1') that refer to Solidity locals (which live in + # `assignments`, not `subst`). + _const_locals: dict[str, int] = {} + for a in assignments: + if isinstance(a, Assignment) and isinstance(a.expr, IntLit): + _const_locals[a.target] = a.expr.value + + def _resolve_addr(expr: Expr) -> Expr: + """Resolve Var references through _const_locals before const-eval.""" + if isinstance(expr, Var) and expr.name in _const_locals: + return IntLit(_const_locals[expr.name]) + if isinstance(expr, Call): + new_args = tuple(_resolve_addr(a) for a in expr.args) + return Call(expr.name, new_args) + return expr + + # Collect __mstore entries from the copy-propagation subst dict. + # These have the form: subst[_inline___mstore_N] = Call("__mstore", [addr, val]) + # After copy propagation, addr and val are fully resolved. + # Collect __mstore entries, resolving addresses to integer constants. + mem_map: dict[int, Expr] = {} + for key, val in subst.items(): + if ( + isinstance(val, Call) + and val.name == "__mstore" + and len(val.args) == 2 + ): + addr = _try_const_eval(_resolve_addr(val.args[0])) + if addr is None: + raise ParseError( + f"__mstore synthetic assignment {key!r} has non-constant " + f"address {val.args[0]!r} after copy propagation. " + f"All mstore addresses must evaluate to constants " + f"(use tmp() in wrappers)." + ) + mem_map[addr] = val.args[1] + + if mem_map: + # Resolve mload calls within mem_map values against the same + # mem_map. This handles cases where e.g. the value at addr 0 + # contains mload(0x1080) which maps to x_hi from mem_map[4224]. + # Iterate until stable (acyclic references converge in one pass). + def _fold_mem_val(expr: Expr) -> Expr: + if isinstance(expr, (IntLit, Var)): + return expr + if isinstance(expr, Call): + if expr.name == "mload" and len(expr.args) == 1: + addr = _try_const_eval(_resolve_addr(expr.args[0])) + if addr is not None and addr in mem_map: + return mem_map[addr] + new_args = tuple(_fold_mem_val(a) for a in expr.args) + return Call(expr.name, new_args) + return expr + + changed = True + for _pass in range(5): + if not changed: + break + changed = False + for addr in list(mem_map.keys()): + new_val = _fold_mem_val(mem_map[addr]) + if new_val is not mem_map[addr]: + mem_map[addr] = new_val + changed = True + + model = _resolve_mloads(model, mem_map, _const_locals, sol_fn_name) + + return model + + +def _resolve_mloads( + model: "FunctionModel", + mem_map: dict[int, Expr], + const_locals: dict[str, int], + fn_name: str, +) -> "FunctionModel": + """Replace ``mload(const_addr)`` calls in a FunctionModel with values + from the memory map. + + Raises ``ParseError`` if any ``mload`` has a non-constant address or + an address not found in the memory map. + """ + def _resolve_addr(expr: Expr) -> Expr: + """Resolve Var references through const_locals before const-eval.""" + if isinstance(expr, Var) and expr.name in const_locals: + return IntLit(const_locals[expr.name]) + if isinstance(expr, Call): + new_args = tuple(_resolve_addr(a) for a in expr.args) + return Call(expr.name, new_args) + return expr + + def _fold(expr: Expr) -> Expr: + if isinstance(expr, (IntLit, Var)): + return expr + if isinstance(expr, Call): + if expr.name == "mload" and len(expr.args) == 1: + addr = _try_const_eval(_resolve_addr(expr.args[0])) + if addr is None: + raise ParseError( + f"mload with non-constant address {expr.args[0]!r} " + f"in {fn_name!r} after copy propagation. " + f"All mload addresses must evaluate to constants." + ) + if addr not in mem_map: + raise ParseError( + f"mload at address {addr} in {fn_name!r} has no " + f"matching mstore. Available addresses: " + f"{sorted(mem_map.keys())}" + ) + return mem_map[addr] + new_args = tuple(_fold(a) for a in expr.args) + return Call(expr.name, new_args) + return expr + + def _fold_stmt(stmt: ModelStatement) -> ModelStatement: + if isinstance(stmt, Assignment): + return Assignment(target=stmt.target, expr=_fold(stmt.expr)) + if isinstance(stmt, ConditionalBlock): + ea = None + if stmt.else_assignments is not None: + ea = tuple( + Assignment(target=a.target, expr=_fold(a.expr)) + for a in stmt.else_assignments + ) + return ConditionalBlock( + condition=_fold(stmt.condition), + assignments=tuple( + Assignment(target=a.target, expr=_fold(a.expr)) + for a in stmt.assignments + ), + modified_vars=stmt.modified_vars, + else_vars=stmt.else_vars, + else_assignments=ea, + ) + raise TypeError(f"Unsupported ModelStatement: {type(stmt)}") + + return FunctionModel( + fn_name=model.fn_name, + assignments=tuple(_fold_stmt(s) for s in model.assignments), + param_names=model.param_names, + return_names=model.return_names, + ) + + +# --------------------------------------------------------------------------- +# Lean emission helpers +# --------------------------------------------------------------------------- + +OP_TO_LEAN_HELPER = { + "add": "evmAdd", + "sub": "evmSub", + "mul": "evmMul", + "div": "evmDiv", + "mod": "evmMod", + "not": "evmNot", + "or": "evmOr", + "and": "evmAnd", + "eq": "evmEq", + "shl": "evmShl", + "shr": "evmShr", + "clz": "evmClz", + "lt": "evmLt", + "gt": "evmGt", + "mulmod": "evmMulmod", +} + +OP_TO_OPCODE = { + "add": "ADD", + "sub": "SUB", + "mul": "MUL", + "div": "DIV", + "mod": "MOD", + "not": "NOT", + "or": "OR", + "and": "AND", + "eq": "EQ", + "shl": "SHL", + "shr": "SHR", + "clz": "CLZ", + "lt": "LT", + "gt": "GT", + "mulmod": "MULMOD", +} + +# Base norm helpers shared by all generators. Per-generator extras (like +# bitLengthPlus1 for cbrt) are merged in via ModelConfig.extra_norm_ops. +_BASE_NORM_HELPERS = { + "add": "normAdd", + "sub": "normSub", + "mul": "normMul", + "div": "normDiv", + "mod": "normMod", + "not": "normNot", + "or": "normOr", + "and": "normAnd", + "eq": "normEq", + "shl": "normShl", + "shr": "normShr", + "clz": "normClz", + "lt": "normLt", + "gt": "normGt", + "mulmod": "normMulmod", +} + + +def validate_ident(name: str, *, what: str) -> None: + if not re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", name): + raise ParseError(f"Invalid {what}: {name!r}") + + +def collect_ops(expr: Expr) -> list[str]: + out: list[str] = [] + if isinstance(expr, Call): + if expr.name in OP_TO_OPCODE: + out.append(expr.name) + for arg in expr.args: + out.extend(collect_ops(arg)) + return out + + +def collect_ops_from_statement(stmt: ModelStatement) -> list[str]: + """Collect opcodes from an Assignment or ConditionalBlock.""" + if isinstance(stmt, Assignment): + return collect_ops(stmt.expr) + if isinstance(stmt, ConditionalBlock): + ops = collect_ops(stmt.condition) + for a in stmt.assignments: + ops.extend(collect_ops(a.expr)) + return ops + raise TypeError(f"Unsupported ModelStatement: {type(stmt)}") + + +def ordered_unique(items: list[str]) -> list[str]: + seen: set[str] = set() + out: list[str] = [] + for item in items: + if item in seen: + continue + seen.add(item) + out.append(item) + return out + + +def emit_expr( + expr: Expr, + *, + op_helper_map: dict[str, str], + call_helper_map: dict[str, str], +) -> str: + if isinstance(expr, IntLit): + return str(expr.value) + if isinstance(expr, Var): + return expr.name + if isinstance(expr, Call): + # Handle __component_N(call) for multi-return function calls. + # Emits Lean tuple projection: (f args).1 for component 0, etc. + m = re.fullmatch(r"__component_(\d+)", expr.name) + if m and len(expr.args) == 1: + idx = int(m.group(1)) + inner = emit_expr(expr.args[0], op_helper_map=op_helper_map, call_helper_map=call_helper_map) + return f"({inner}).{idx + 1}" + + # Handle __ite(cond, if_val, else_val) from leave-handling. + # Emits: if (cond) ≠ 0 then if_val else else_val + if expr.name == "__ite" and len(expr.args) == 3: + cond = emit_expr(expr.args[0], op_helper_map=op_helper_map, call_helper_map=call_helper_map) + if_val = emit_expr(expr.args[1], op_helper_map=op_helper_map, call_helper_map=call_helper_map) + else_val = emit_expr(expr.args[2], op_helper_map=op_helper_map, call_helper_map=call_helper_map) + return f"if ({cond}) ≠ 0 then {if_val} else {else_val}" + + helper = op_helper_map.get(expr.name) + if helper is None: + helper = call_helper_map.get(expr.name) + if helper is None: + raise ParseError(f"Unsupported call in Lean emitter: {expr.name!r}") + args = " ".join(f"({emit_expr(a, op_helper_map=op_helper_map, call_helper_map=call_helper_map)})" for a in expr.args) + return f"{helper} {args}".rstrip() + raise TypeError(f"Unsupported Expr node: {type(expr)}") + + +# --------------------------------------------------------------------------- +# Per-generator configuration +# --------------------------------------------------------------------------- + +@dataclass(frozen=True) +class ModelConfig: + """All the per-library knobs that differ between cbrt and sqrt generators.""" + # Ordered Solidity function names to model. + function_order: tuple[str, ...] + # sol_fn_name → Lean model base name (e.g. "_cbrt" → "model_cbrt") + model_names: dict[str, str] + # Lean header line (e.g. "Auto-generated from Solidity Cbrt assembly …") + header_comment: str + # Generator script path for the header (e.g. "formal/cbrt/generate_cbrt_model.py") + generator_label: str + # Additional norm-helper entries beyond the base set. + extra_norm_ops: dict[str, str] + # Additional Lean definitions emitted right before normLt/normGt. + extra_lean_defs: str + # Optional AST rewrite applied to expressions in the Nat model. + norm_rewrite: Callable[[Expr], Expr] | None + # Inner function name that the public functions depend on. + inner_fn: str + # Optional per-function expected parameter counts for disambiguation. + # When set, find_function uses param count to pick among homonymous + # Yul functions (e.g. single-param _sqrt vs two-param _sqrt). + n_params: dict[str, int] | None = None + # When True, variables matching var__ (Solidity-declared + # locals) are kept in the model instead of being copy-propagated. + # Needed for functions with mixed assembly + Solidity code. + keep_solidity_locals: bool = False + # Function names whose find_function should use exclude_known=True, + # i.e. prefer candidates that do NOT reference already-targeted + # functions. Used to select leaf functions (e.g. 256-bit Sqrt.sqrt) + # over higher-level wrappers with the same name. + exclude_known: frozenset[str] = frozenset() + # Function names for which the normalized (unbounded Nat) model + # variation should be suppressed. The norm model uses normShl/normMul + # etc. which do NOT match EVM uint256 semantics. For wrapper functions + # whose proofs bridge the EVM model directly, the norm model is unused. + skip_norm: frozenset[str] = frozenset() + + # -- CLI defaults -- + default_source_label: str = "" + default_namespace: str = "" + default_output: str = "" + cli_description: str = "" + + +# --------------------------------------------------------------------------- +# High-level pipeline (shared by both generators) +# --------------------------------------------------------------------------- + + +def build_model_body( + assignments: tuple[ModelStatement, ...], + *, + evm: bool, + config: ModelConfig, + param_names: tuple[str, ...] = ("x",), + return_names: tuple[str, ...] = ("z",), +) -> str: + lines: list[str] = [] + norm_helpers = {**_BASE_NORM_HELPERS, **config.extra_norm_ops} + + if evm: + for p in param_names: + lines.append(f" let {p} := u256 {p}") + call_map = {fn: f"{config.model_names[fn]}_evm" for fn in config.function_order} + op_map = OP_TO_LEAN_HELPER + else: + call_map = dict(config.model_names) + op_map = norm_helpers + + def _emit_rhs(expr: Expr) -> str: + rhs_expr = expr + if not evm and config.norm_rewrite is not None: + rhs_expr = config.norm_rewrite(rhs_expr) + return emit_expr(rhs_expr, op_helper_map=op_map, call_helper_map=call_map) + + for stmt in assignments: + if isinstance(stmt, ConditionalBlock): + # Emit Lean tuple-destructuring if-then-else: + # let (v1, v2) := if cond ≠ 0 then + # let v1 := ... + # ... + # (v1, v2) + # else (v1, v2) + cond_str = _emit_rhs(stmt.condition) + mvars = stmt.modified_vars + evars = stmt.else_vars if stmt.else_vars is not None else mvars + if len(mvars) == 1: + lhs = mvars[0] + tup = mvars[0] + else: + lhs = f"({', '.join(mvars)})" + tup = f"({', '.join(mvars)})" + if len(evars) == 1: + else_tup = evars[0] + else: + else_tup = f"({', '.join(evars)})" + lines.append(f" let {lhs} := if ({cond_str}) ≠ 0 then") + for a in stmt.assignments: + rhs = _emit_rhs(a.expr) + lines.append(f" let {a.target} := {rhs}") + lines.append(f" {tup}") + if stmt.else_assignments is not None: + lines.append(f" else") + for a in stmt.else_assignments: + rhs = _emit_rhs(a.expr) + lines.append(f" let {a.target} := {rhs}") + # Build the else tuple from the else-body's modified vars. + # Variables in modified_vars but not assigned in else_body + # keep their pre-if name. + else_assigned = {a.target for a in stmt.else_assignments} + else_tuple_parts = [] + for v in mvars: + if v in else_assigned: + else_tuple_parts.append(v) + elif evars is not None: + idx = list(mvars).index(v) + else_tuple_parts.append(evars[idx] if idx < len(evars) else v) + else: + else_tuple_parts.append(v) + if len(else_tuple_parts) == 1: + lines.append(f" {else_tuple_parts[0]}") + else: + lines.append(f" ({', '.join(else_tuple_parts)})") + else: + lines.append(f" else {else_tup}") + elif isinstance(stmt, Assignment): + rhs = _emit_rhs(stmt.expr) + lines.append(f" let {stmt.target} := {rhs}") + else: + raise TypeError(f"Unsupported ModelStatement: {type(stmt)}") + + if len(return_names) == 1: + lines.append(f" {return_names[0]}") + else: + lines.append(f" ({', '.join(return_names)})") + return "\n".join(lines) + + +def render_function_defs(models: list[FunctionModel], config: ModelConfig) -> str: + parts: list[str] = [] + for model in models: + model_base = config.model_names[model.fn_name] + evm_name = f"{model_base}_evm" + evm_body = build_model_body( + model.assignments, evm=True, config=config, + param_names=model.param_names, return_names=model.return_names, + ) + + param_sig = " ".join(f"{p}" for p in model.param_names) + if len(model.return_names) == 1: + ret_type = "Nat" + else: + ret_type = " × ".join("Nat" for _ in model.return_names) + parts.append( + f"/-- Opcode-faithful auto-generated model of `{model.fn_name}` with uint256 EVM semantics. -/\n" + f"def {evm_name} ({param_sig} : Nat) : {ret_type} :=\n" + f"{evm_body}\n" + ) + if model.fn_name not in config.skip_norm: + norm_name = model_base + norm_body = build_model_body( + model.assignments, evm=False, config=config, + param_names=model.param_names, return_names=model.return_names, + ) + parts.append( + f"/-- Normalized auto-generated model of `{model.fn_name}` on Nat arithmetic. -/\n" + f"def {norm_name} ({param_sig} : Nat) : {ret_type} :=\n" + f"{norm_body}\n" + ) + return "\n".join(parts) + + +def build_lean_source( + *, + models: list[FunctionModel], + source_path: str, + namespace: str, + config: ModelConfig, +) -> str: + generated_at = dt.datetime.now(dt.timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + modeled_functions = ", ".join(model.fn_name for model in models) + + raw_ops: list[str] = [] + for model in models: + for stmt in model.assignments: + raw_ops.extend(collect_ops_from_statement(stmt)) + opcodes = ordered_unique([OP_TO_OPCODE[name] for name in raw_ops]) + opcodes_line = ", ".join(opcodes) + + function_defs = render_function_defs(models, config) + + src = ( + "import Init\n\n" + f"namespace {namespace}\n\n" + f"/-- {config.header_comment} -/\n" + f"-- Source: {source_path}\n" + f"-- Modeled functions: {modeled_functions}\n" + f"-- Generated by: {config.generator_label}\n" + f"-- Generated at (UTC): {generated_at}\n" + f"-- Modeled opcodes/Yul builtins: {opcodes_line}\n\n" + "def WORD_MOD : Nat := 2 ^ 256\n\n" + "def u256 (x : Nat) : Nat :=\n" + " x % WORD_MOD\n\n" + "def evmAdd (a b : Nat) : Nat :=\n" + " u256 (u256 a + u256 b)\n\n" + "def evmSub (a b : Nat) : Nat :=\n" + " u256 (u256 a + WORD_MOD - u256 b)\n\n" + "def evmMul (a b : Nat) : Nat :=\n" + " u256 (u256 a * u256 b)\n\n" + "def evmDiv (a b : Nat) : Nat :=\n" + " let aa := u256 a\n" + " let bb := u256 b\n" + " if bb = 0 then 0 else aa / bb\n\n" + "def evmMod (a b : Nat) : Nat :=\n" + " let aa := u256 a\n" + " let bb := u256 b\n" + " if bb = 0 then 0 else aa % bb\n\n" + "def evmNot (a : Nat) : Nat :=\n" + " WORD_MOD - 1 - u256 a\n\n" + "def evmOr (a b : Nat) : Nat :=\n" + " u256 a ||| u256 b\n\n" + "def evmAnd (a b : Nat) : Nat :=\n" + " u256 a &&& u256 b\n\n" + "def evmEq (a b : Nat) : Nat :=\n" + " if u256 a = u256 b then 1 else 0\n\n" + "def evmShl (shift value : Nat) : Nat :=\n" + " let s := u256 shift\n" + " let v := u256 value\n" + " if s < 256 then u256 (v * 2 ^ s) else 0\n\n" + "def evmShr (shift value : Nat) : Nat :=\n" + " let s := u256 shift\n" + " let v := u256 value\n" + " if s < 256 then v / 2 ^ s else 0\n\n" + "def evmClz (value : Nat) : Nat :=\n" + " let v := u256 value\n" + " if v = 0 then 256 else 255 - Nat.log2 v\n\n" + "def evmLt (a b : Nat) : Nat :=\n" + " if u256 a < u256 b then 1 else 0\n\n" + "def evmGt (a b : Nat) : Nat :=\n" + " if u256 a > u256 b then 1 else 0\n\n" + "def evmMulmod (a b n : Nat) : Nat :=\n" + " let aa := u256 a; let bb := u256 b; let nn := u256 n\n" + " if nn = 0 then 0 else (aa * bb) % nn\n\n" + "def normAdd (a b : Nat) : Nat := a + b\n\n" + "def normSub (a b : Nat) : Nat := a - b\n\n" + "def normMul (a b : Nat) : Nat := a * b\n\n" + "def normDiv (a b : Nat) : Nat := a / b\n\n" + "def normMod (a b : Nat) : Nat := a % b\n\n" + "def normNot (a : Nat) : Nat := WORD_MOD - 1 - a\n\n" + "def normOr (a b : Nat) : Nat := a ||| b\n\n" + "def normAnd (a b : Nat) : Nat := a &&& b\n\n" + "def normEq (a b : Nat) : Nat :=\n" + " if a = b then 1 else 0\n\n" + "def normShl (shift value : Nat) : Nat := value <<< shift\n\n" + "def normShr (shift value : Nat) : Nat := value / 2 ^ shift\n\n" + "def normClz (value : Nat) : Nat :=\n" + " if value = 0 then 256 else 255 - Nat.log2 value\n\n" + f"{config.extra_lean_defs}" + "def normLt (a b : Nat) : Nat :=\n" + " if a < b then 1 else 0\n\n" + "def normGt (a b : Nat) : Nat :=\n" + " if a > b then 1 else 0\n\n" + "def normMulmod (a b n : Nat) : Nat :=\n" + " if n = 0 then 0 else (a * b) % n\n\n" + f"{function_defs}\n" + f"end {namespace}\n" + ) + return src + + +def parse_function_selection( + args: argparse.Namespace, + config: ModelConfig, +) -> tuple[str, ...]: + selected: list[str] = [] + + if args.function: + selected.extend(args.function) + if args.functions: + for fn in args.functions.split(","): + name = fn.strip() + if name: + selected.append(name) + + if not selected: + selected = list(config.function_order) + + allowed = set(config.function_order) + bad = [f for f in selected if f not in allowed] + if bad: + raise ParseError(f"Unsupported function(s): {', '.join(bad)}") + + # Public functions depend on the inner function. + if any(fn != config.inner_fn for fn in selected) and config.inner_fn not in selected: + selected.append(config.inner_fn) + + selected_set = set(selected) + return tuple(fn for fn in config.function_order if fn in selected_set) + + +def run(config: ModelConfig) -> int: + """Main entry point shared by both generators.""" + ap = argparse.ArgumentParser(description=config.cli_description) + ap.add_argument( + "--yul", required=True, + help="Path to Yul IR file, or '-' for stdin (from `forge inspect ... ir`)", + ) + ap.add_argument( + "--source-label", default=config.default_source_label, + help="Source label for the Lean header comment", + ) + ap.add_argument( + "--functions", default="", + help=f"Comma-separated function names (default: {','.join(config.function_order)})", + ) + ap.add_argument( + "--function", action="append", + help="Optional repeatable function selector", + ) + ap.add_argument( + "--namespace", default=config.default_namespace, + help="Lean namespace for generated definitions", + ) + ap.add_argument( + "--output", default=config.default_output, + help="Output Lean file path", + ) + args = ap.parse_args() + + validate_ident(args.namespace, what="Lean namespace") + + selected_functions = parse_function_selection(args, config) + + if args.yul == "-": + yul_text = sys.stdin.read() + else: + yul_text = pathlib.Path(args.yul).read_text() + + tokens = tokenize_yul(yul_text) + + # Collect all parseable function definitions for inlining. + fn_table = YulParser(tokens).collect_all_functions() + + fn_map: dict[str, str] = {} + yul_functions: dict[str, YulFunction] = {} + + # First pass: find target functions and record their Yul names. + known_yul_names: set[str] = set() + for sol_name in selected_functions: + p = YulParser(tokens) + np = config.n_params.get(sol_name) if config.n_params else None + yf = p.find_function(sol_name, n_params=np, + known_yul_names=known_yul_names or None, + exclude_known=sol_name in config.exclude_known) + fn_map[yf.yul_name] = sol_name + yul_functions[sol_name] = yf + known_yul_names.add(yf.yul_name) + + # Remove target functions from the inlining table so they remain + # as named calls in the model (e.g. sqrt calling _sqrt → model_sqrt). + for yul_name in fn_map: + fn_table.pop(yul_name, None) + + if fn_table: + print(f"Collected {len(fn_table)} function definition(s) for inlining") + + # Second pass: inline non-target function calls. + for sol_name in selected_functions: + yf = yul_functions[sol_name] + yf = _inline_yul_function(yf, fn_table) + yul_functions[sol_name] = yf + + models = [ + yul_function_to_model( + yul_functions[fn], fn, fn_map, + keep_solidity_locals=config.keep_solidity_locals, + ) + for fn in selected_functions + ] + + lean_src = build_lean_source( + models=models, + source_path=args.source_label, + namespace=args.namespace, + config=config, + ) + + out_path = pathlib.Path(args.output) + out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.write_text(lean_src) + + print(f"Generated {out_path}") + for model in models: + print(f"Parsed {len(model.assignments)} assignments for {model.fn_name}") + + raw_ops: list[str] = [] + for model in models: + for stmt in model.assignments: + raw_ops.extend(collect_ops_from_statement(stmt)) + opcodes = ordered_unique([OP_TO_OPCODE[name] for name in raw_ops]) + print(f"Modeled opcodes: {', '.join(opcodes)}") + + return 0 diff --git a/foundry.toml b/foundry.toml index d86543996..3f2d8ff65 100644 --- a/foundry.toml +++ b/foundry.toml @@ -15,7 +15,7 @@ optimizer_runs = 2_000 evm_version = "osaka" gas_limit = 16_777_216 block_gas_limit = 16_777_216 -no_match_path = "test/integration/*" +no_match_path = "{test/integration/*,test/0.8.25/formal-model/*}" # needed for marktoda/forge-gas-snapshot ffi = true fs_permissions = [{ access = "read-write", path = ".forge-snapshots/" }, { access = "read", path = "out" }, { access = "read", path = "script/" }] @@ -36,6 +36,11 @@ block_gas_limit = 50_000_000 code_size_limit = 200_000 disable_block_gas_limit = true +[profile.formal-model] +match_path = "test/0.8.25/formal-model/*" +no_match_path = "😡 toml has no null value 😡" +fuzz.runs = 20_000 + [fuzz] runs = 100_000 max_test_rejects = 1_000_000 diff --git a/src/utils/512Math.sol b/src/utils/512Math.sol index 0c22df9ae..59ff4f4bf 100644 --- a/src/utils/512Math.sol +++ b/src/utils/512Math.sol @@ -1696,10 +1696,109 @@ library Lib512MathArithmetic { return omodAlt(r, y, r); } + //// The following 512-bit square root implementation is a realization of Zimmerman's "Karatsuba + //// Square Root" algorithm https://inria.hal.science/inria-00072854/document . This approach is + //// inspired by https://github.com/SimonSuckut/Solidity_Uint512/ . These helper functions are + //// broken out separately to ease formal verification. + + /// One square root Babylonian step: r = ⌊(x/r + r) / 2⌋ + function _sqrt_babylonianStep(uint256 x, uint256 r) private pure returns (uint256 r_out) { + unchecked { + return x.unsafeDiv(r) + r >> 1; + } + } + + /// 6 Babylonian steps from fixed seed + floor correction + residue for Karatsuba + /// + /// Implementing this as: + /// uint256 r_hi = x_hi.sqrt(); + /// uint256 res = x_hi - r_hi * r_hi; + /// is correct, but duplicates the normalization that we do in `_sqrt` and performs a + /// more-costly initialization step. solc is not very smart. It can't optimize away the + /// initialization step of `Sqrt.sqrt`. It also can't optimize the calculation of `res`, so + /// doing it in Yul is meaningfully more gas efficient. + function _sqrt_baseCase(uint256 x_hi) private pure returns (uint256 r_hi, uint256 res) { + // Seed with √(2²⁵⁵), the geometric mean of the normalized √x_hi range [2¹²⁷, + // 2¹²⁸). This balances worst-case over/underestimate (ε ≈ ±0.414/0.293), giving + // >128 bits of precision in 6 Babylonian steps + r_hi = 0xb504f333f9de6484597d89b3754abe9f; + + // 6 Babylonian steps is sufficient for convergence + r_hi = _sqrt_babylonianStep(x_hi, r_hi); + r_hi = _sqrt_babylonianStep(x_hi, r_hi); + r_hi = _sqrt_babylonianStep(x_hi, r_hi); + r_hi = _sqrt_babylonianStep(x_hi, r_hi); + r_hi = _sqrt_babylonianStep(x_hi, r_hi); + r_hi = _sqrt_babylonianStep(x_hi, r_hi); + + // The Babylonian step can oscillate between ⌊√x_hi⌋ and ⌈√x_hi⌉. Clean that up. + r_hi = r_hi.unsafeDec(x_hi.unsafeDiv(r_hi) < r_hi); + + assembly ("memory-safe") { + // This is cheaper than + // unchecked { + // uint256 res = x_hi - r_hi * r_hi; + // } + // for no clear reason + res := sub(x_hi, mul(r_hi, r_hi)) + } + } + + /// Karatsuba quotient with carry correction + /// + /// `res` is (almost) a single limb. Create a new (almost) machine word `n` with `res` as + /// the upper limb and shifting in the next limb of `x` (namely `x_lo >> 128`) as the + /// lower limb. The next step of Zimmerman's algorithm is: + /// r_lo = n / (2 · r_hi) + /// res = n % (2 · r_hi) + function _sqrt_karatsubaQuotient(uint256 res, uint256 x_lo, uint256 r_hi) + private + pure + returns (uint256 r_lo, uint256 res_out) + { + assembly ("memory-safe") { + let n := or(shl(0x80, res), shr(0x80, x_lo)) + let d := shl(0x01, r_hi) + r_lo := div(n, d) + + let c := shr(0x80, res) + res_out := mod(n, d) + + // It's possible that `n` was 257 bits and overflowed (`res` was not just a single + // limb). Explicitly handling the carry avoids 512-bit division. + if c { + r_lo := add(r_lo, div(not(0x00), d)) + res_out := add(res_out, add(0x01, mod(not(0x00), d))) + r_lo := add(r_lo, div(res_out, d)) + res_out := mod(res_out, d) + } + } + } + + /// Combine `r_hi` with `r_lo` and perform the 257-bit underflow correction + /// + /// The final step of Zimmerman's algorithm is: if res · 2¹²⁸ + x_lo % 2¹²⁸ < r_lo², decrement + /// `r`. We have to do this in a complicated manner because both `res` and `r_lo` can be + /// 𝑠𝑙𝑖𝑔ℎ𝑡𝑙𝑦 longer than 1 limb (128 bits). This is more efficient than performing the full + /// 257-bit comparison. + function _sqrt_correction(uint256 r_hi, uint256 r_lo, uint256 res, uint256 x_lo) + private + pure + returns (uint256 r) + { + unchecked { + r = (r_hi << 128) + r_lo; + r = r.unsafeDec( + ((res >> 128) < (r_lo >> 128)) + .or( + ((res >> 128) == (r_lo >> 128)) + .and((res << 128) | (x_lo & 0xffffffffffffffffffffffffffffffff) < r_lo * r_lo) + ) + ); + } + } + function _sqrt(uint256 x_hi, uint256 x_lo) private pure returns (uint256 r) { - /// Our general approach is to apply Zimmerman's "Karatsuba Square Root" algorithm - /// https://inria.hal.science/inria-00072854/document with the helpers from Solady and - /// 512Math. This approach is inspired by https://github.com/SimonSuckut/Solidity_Uint512/ unchecked { // Normalize `x` so the top word has its MSB in bit 255 or 254. This makes the "shift // back" step exact. @@ -1711,74 +1810,15 @@ library Lib512MathArithmetic { // We treat `r` as a ≤2-limb bigint where each limb is half a machine word (128 bits). // Spliting √x in this way lets us apply "ordinary" 256-bit `sqrt` to the top word of // `x`. Then we can recover the bottom limb of `r` without 512-bit division. - // - // Implementing this as: - // uint256 r_hi = x_hi.sqrt(); - // is correct, but duplicates the normalization that we just did above and performs a - // more-costly initialization step. solc is not smart enough to optimize this away, so - // we inline and do it ourselves. - uint256 r_hi; - assembly ("memory-safe") { - // Seed with √(2²⁵⁵), the geometric mean of the normalized √x_hi range [2¹²⁷, - // 2¹²⁸). This balances worst-case over/underestimate (ε ≈ ±0.414/0.293), giving - // >128 bits of precision in 6 Babylonian steps - r_hi := 0xb504f333f9de6484597d89b3754abe9f - - // 6 Babylonian steps is sufficient for convergence - r_hi := shr(0x01, add(r_hi, div(x_hi, r_hi))) - r_hi := shr(0x01, add(r_hi, div(x_hi, r_hi))) - r_hi := shr(0x01, add(r_hi, div(x_hi, r_hi))) - r_hi := shr(0x01, add(r_hi, div(x_hi, r_hi))) - r_hi := shr(0x01, add(r_hi, div(x_hi, r_hi))) - r_hi := shr(0x01, add(r_hi, div(x_hi, r_hi))) - - // The Babylonian step can oscillate between ⌊√x_hi⌋ and ⌈√x_hi⌉. Clean that up. - r_hi := sub(r_hi, lt(div(x_hi, r_hi), r_hi)) - } - - // This is cheaper than - // uint256 res = x_hi - r_hi * r_hi; - // for no clear reason - uint256 res; - assembly ("memory-safe") { - res := sub(x_hi, mul(r_hi, r_hi)) - } + (uint256 r_hi, uint256 res) = _sqrt_baseCase(x_hi); + // The next titular Karatsuba step extends the upper limb of `r` to approximate the + // lower limb. uint256 r_lo; - // `res` is (almost) a single limb. Create a new (almost) machine word `n` with `res` as - // the upper limb and shifting in the next limb of `x` (namely `x_lo >> 128`) as the - // lower limb. The next step of Zimmerman's algorithm is: - // r_lo = n / (2 · r_hi) - // res = n % (2 · r_hi) - assembly ("memory-safe") { - let n := or(shl(0x80, res), shr(0x80, x_lo)) - let d := shl(0x01, r_hi) - r_lo := div(n, d) + (r_lo, res) = _sqrt_karatsubaQuotient(res, x_lo, r_hi); - let c := shr(0x80, res) - res := mod(n, d) - - // It's possible that `n` was 257 bits and overflowed (`res` was not just a single - // limb). Explicitly handling the carry avoids 512-bit division. - if c { - r_lo := add(r_lo, div(not(0x00), d)) - res := add(res, add(0x01, mod(not(0x00), d))) - r_lo := add(r_lo, div(res, d)) - res := mod(res, d) - } - } - r = (r_hi << 128) + r_lo; - - // Then, if res · 2¹²⁸ + x_lo % 2¹²⁸ < r_lo², decrement `r`. We have to do this in a - // complicated manner because both `res` and `r_lo` can be _slightly_ longer than 1 limb - // (128 bits). This is more efficient than performing the full 257-bit comparison. - r = r.unsafeDec( - ((res >> 128) < (r_lo >> 128)) - .or( - ((res >> 128) == (r_lo >> 128)) - .and((res << 128) | (x_lo & 0xffffffffffffffffffffffffffffffff) < r_lo * r_lo) - ) - ); + // The Karatsuba step is an approximation. This refinement makes it exactly ⌊√x⌋ + r = _sqrt_correction(r_hi, r_lo, res, x_lo); // Un-normalize return r >> shift; @@ -1798,15 +1838,16 @@ library Lib512MathArithmetic { function osqrtUp(uint512 r, uint512 x) internal pure returns (uint512) { (uint256 x_hi, uint256 x_lo) = x.into(); + uint256 r_hi; + uint256 r_lo; if (x_hi == 0) { - return r.from(0, x_lo.sqrtUp()); + r_lo = x_lo.sqrtUp(); + } else { + r_lo = _sqrt(x_hi, x_lo); + (uint256 r2_hi, uint256 r2_lo) = _mul(r_lo, r_lo); + (r_hi, r_lo) = _add(0, r_lo, _gt(x_hi, x_lo, r2_hi, r2_lo).toUint()); } - uint256 r_lo = _sqrt(x_hi, x_lo); - - (uint256 r2_hi, uint256 r2_lo) = _mul(r_lo, r_lo); - uint256 r_hi; - (r_hi, r_lo) = _add(0, r_lo, _gt(x_hi, x_lo, r2_hi, r2_lo).toUint()); return r.from(r_hi, r_lo); } @@ -2117,7 +2158,7 @@ struct uint512_external { library Lib512MathExternal { function from(uint512 r, uint512_external memory x) internal pure returns (uint512) { assembly ("memory-safe") { - // This *could* be done with `mcopy`, but that would mean giving up compatibility with + // This 𝐜𝐨𝐮𝐥𝐝 be done with `mcopy`, but that would mean giving up compatibility with // Shanghai (or less) chains. If you care about gas efficiency, you should be using // `into()` instead. mstore(r, mload(x)) diff --git a/src/vendor/Cbrt.sol b/src/vendor/Cbrt.sol index 0ace38327..520d3a3bf 100644 --- a/src/vendor/Cbrt.sol +++ b/src/vendor/Cbrt.sol @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -pragma solidity ^0.8.25; +pragma solidity ^0.8.33; // @author Modified from Solady by Vectorized and Akshay Tarpara https://github.com/Vectorized/solady/blob/ff6256a18851749e765355b3e21dc9bfa417255b/src/utils/clz/FixedPointMathLib.sol#L799-L822 under the MIT license. library Cbrt { diff --git a/src/vendor/Sqrt.sol b/src/vendor/Sqrt.sol index c5e9b81a0..b51eac9da 100644 --- a/src/vendor/Sqrt.sol +++ b/src/vendor/Sqrt.sol @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -pragma solidity ^0.8.25; +pragma solidity ^0.8.33; // @author Modified from Solady by Vectorized and Akshay Tarpara https://github.com/Vectorized/solady/blob/1198c9f70b30d472a7d0ec021bec080622191b03/src/utils/clz/FixedPointMathLib.sol#L769-L797 under the MIT license. library Sqrt { diff --git a/src/wrappers/CbrtWrapper.sol b/src/wrappers/CbrtWrapper.sol new file mode 100644 index 000000000..bad2aef35 --- /dev/null +++ b/src/wrappers/CbrtWrapper.sol @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.25; + +import {Cbrt} from "src/vendor/Cbrt.sol"; + +/// @dev Thin wrapper exposing Cbrt's internal functions for `forge inspect ... ir`. +/// Function names are prefixed with `wrap_` to avoid Yul name collisions with the +/// library functions, keeping the IR unambiguous for the formal-proof code generator. +contract CbrtWrapper { + function wrap_cbrt(uint256 x) external pure returns (uint256) { + return Cbrt.cbrt(x); + } + function wrap_cbrtUp(uint256 x) external pure returns (uint256) { + return Cbrt.cbrtUp(x); + } +} diff --git a/src/wrappers/Sqrt512Wrapper.sol b/src/wrappers/Sqrt512Wrapper.sol new file mode 100644 index 000000000..8c4ee0305 --- /dev/null +++ b/src/wrappers/Sqrt512Wrapper.sol @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +pragma solidity =0.8.33; + +import {uint512, alloc, tmp} from "src/utils/512Math.sol"; + +/// @dev Thin wrapper exposing 512Math's sqrt functions for `forge inspect ... ir`. +/// The public `sqrt(uint512)` calls `_sqrt(x_hi, x_lo)` internally, so both +/// appear in the Yul IR. The driver script disambiguates by parameter count. +/// +/// The `wrap_sqrt512` and `wrap_sqrt512Up` functions use `alloc()` for their +/// own testing purposes. For model generation, `wrap_sqrt512` and +/// `wrap_osqrtUp` use `tmp()` (fixed address 0) so that the Yul IR's +/// mstore/mload pairs can be folded to direct parameter references by the +/// model generator. +contract Sqrt512Wrapper { + function wrap_sqrt512(uint256 x_hi, uint256 x_lo) external pure returns (uint256) { + return tmp().from(x_hi, x_lo).sqrt(); + } + + function wrap_osqrtUp(uint256 x_hi, uint256 x_lo) external pure returns (uint256, uint256) { + uint512 x; + assembly { // not "memory-safe" + x := 0x1080 + } + return tmp().osqrtUp(x.from(x_hi, x_lo)).into(); + } +} diff --git a/src/wrappers/SqrtWrapper.sol b/src/wrappers/SqrtWrapper.sol new file mode 100644 index 000000000..126ec470c --- /dev/null +++ b/src/wrappers/SqrtWrapper.sol @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.25; + +import {Sqrt} from "src/vendor/Sqrt.sol"; + +/// @dev Thin wrapper exposing Sqrt's internal functions for `forge inspect ... ir`. +/// Function names are prefixed with `wrap_` to avoid Yul name collisions with the +/// library functions, keeping the IR unambiguous for the formal-proof code generator. +contract SqrtWrapper { + function wrap_sqrt(uint256 x) external pure returns (uint256) { + return Sqrt.sqrt(x); + } + function wrap_sqrtUp(uint256 x) external pure returns (uint256) { + return Sqrt.sqrtUp(x); + } +} diff --git a/test/0.8.25/Cbrt.t.sol b/test/0.8.25/Cbrt.t.sol index 189d3bd4c..c5638525c 100644 --- a/test/0.8.25/Cbrt.t.sol +++ b/test/0.8.25/Cbrt.t.sol @@ -12,8 +12,16 @@ contract CbrtTest is Test { uint256 private constant _CBRT_FLOOR_MAX_UINT256_CUBE = 0xffffffffffffffffffffef214b5539a2d22f71387253e480168f34c9da3f5898; - function testCbrt(uint256 x) external pure { - uint256 r = x.cbrt(); + function _cbrtFloor(uint256 x) internal virtual returns (uint256) { + return x.cbrt(); + } + + function _cbrtUp(uint256 x) internal virtual returns (uint256) { + return x.cbrtUp(); + } + + function testCbrt(uint256 x) external { + uint256 r = _cbrtFloor(x); assertLe(r * r * r, x, "cbrt too high"); if (x < _CBRT_FLOOR_MAX_UINT256_CUBE) { r++; @@ -23,8 +31,8 @@ contract CbrtTest is Test { } } - function testCbrtUp(uint256 x) external pure { - uint256 r = x.cbrtUp(); + function testCbrtUp(uint256 x) external { + uint256 r = _cbrtUp(x); if (x <= _CBRT_FLOOR_MAX_UINT256_CUBE) { assertGe(r * r * r, x, "cbrtUp too low"); } else { @@ -38,17 +46,17 @@ contract CbrtTest is Test { } } - function testCbrtUp_overflowCubeRange(uint256 x) external pure { + function testCbrtUp_overflowCubeRange(uint256 x) external { x = bound(x, _CBRT_FLOOR_MAX_UINT256_CUBE + 1, type(uint256).max); - assertEq(x.cbrt(), _CBRT_FLOOR_MAX_UINT256, "cbrt overflow-cube range"); - assertEq(x.cbrtUp(), _CBRT_CEIL_MAX_UINT256, "cbrtUp overflow-cube range"); + assertEq(_cbrtFloor(x), _CBRT_FLOOR_MAX_UINT256, "cbrt overflow-cube range"); + assertEq(_cbrtUp(x), _CBRT_CEIL_MAX_UINT256, "cbrtUp overflow-cube range"); } - function testCbrtUp_overflowCubeBoundary() external pure { + function testCbrtUp_overflowCubeBoundary() external { uint256 x = _CBRT_FLOOR_MAX_UINT256_CUBE; - assertEq(x.cbrt(), _CBRT_FLOOR_MAX_UINT256, "cbrt boundary"); - assertEq(x.cbrtUp(), _CBRT_FLOOR_MAX_UINT256, "cbrtUp boundary"); + assertEq(_cbrtFloor(x), _CBRT_FLOOR_MAX_UINT256, "cbrt boundary"); + assertEq(_cbrtUp(x), _CBRT_FLOOR_MAX_UINT256, "cbrtUp boundary"); } } diff --git a/test/0.8.25/Sqrt.t.sol b/test/0.8.25/Sqrt.t.sol index f9c051113..e1850cf05 100644 --- a/test/0.8.25/Sqrt.t.sol +++ b/test/0.8.25/Sqrt.t.sol @@ -12,8 +12,16 @@ contract SqrtTest is Test { uint256 private constant _SQRT_FLOOR_MAX_UINT256_SQUARED = 0xfffffffffffffffffffffffffffffffe00000000000000000000000000000001; - function testSqrt(uint256 x) external pure { - uint256 r = x.sqrt(); + function _sqrtFloor(uint256 x) internal virtual returns (uint256) { + return x.sqrt(); + } + + function _sqrtUp(uint256 x) internal virtual returns (uint256) { + return x.sqrtUp(); + } + + function testSqrt(uint256 x) external { + uint256 r = _sqrtFloor(x); assertLe(r * r, x, "sqrt too high"); if (x < _SQRT_FLOOR_MAX_UINT256_SQUARED) { r++; @@ -23,8 +31,8 @@ contract SqrtTest is Test { } } - function testSqrtUp(uint256 x) external pure { - uint256 r = x.sqrtUp(); + function testSqrtUp(uint256 x) external { + uint256 r = _sqrtUp(x); if (x <= _SQRT_FLOOR_MAX_UINT256_SQUARED) { assertGe(r * r, x, "sqrtUp too low"); } else { @@ -38,17 +46,17 @@ contract SqrtTest is Test { } } - function testSqrtUp_overflowSquareRange(uint256 x) external pure { + function testSqrtUp_overflowSquareRange(uint256 x) external { x = bound(x, _SQRT_FLOOR_MAX_UINT256_SQUARED + 1, type(uint256).max); - assertEq(x.sqrt(), _SQRT_FLOOR_MAX_UINT256, "sqrt overflow-square range"); - assertEq(x.sqrtUp(), _SQRT_CEIL_MAX_UINT256, "sqrtUp overflow-square range"); + assertEq(_sqrtFloor(x), _SQRT_FLOOR_MAX_UINT256, "sqrt overflow-square range"); + assertEq(_sqrtUp(x), _SQRT_CEIL_MAX_UINT256, "sqrtUp overflow-square range"); } - function testSqrtUp_overflowSquareBoundary() external pure { + function testSqrtUp_overflowSquareBoundary() external { uint256 x = _SQRT_FLOOR_MAX_UINT256_SQUARED; - assertEq(x.sqrt(), _SQRT_FLOOR_MAX_UINT256, "sqrt boundary"); - assertEq(x.sqrtUp(), _SQRT_FLOOR_MAX_UINT256, "sqrtUp boundary"); + assertEq(_sqrtFloor(x), _SQRT_FLOOR_MAX_UINT256, "sqrt boundary"); + assertEq(_sqrtUp(x), _SQRT_FLOOR_MAX_UINT256, "sqrtUp boundary"); } } diff --git a/test/0.8.25/formal-model/CbrtModel.t.sol b/test/0.8.25/formal-model/CbrtModel.t.sol new file mode 100644 index 000000000..8b38cca21 --- /dev/null +++ b/test/0.8.25/formal-model/CbrtModel.t.sol @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.25; + +import {CbrtTest} from "../Cbrt.t.sol"; + +/// @dev Runs the CbrtTest fuzz suite against the generated Lean model +/// via `vm.ffi`. Requires the `cbrt-model` binary to be pre-built: +/// cd formal/cbrt/CbrtProof && lake build cbrt-model +contract CbrtModelTest is CbrtTest { + string private constant _BIN = "formal/cbrt/CbrtProof/.lake/build/bin/cbrt-model"; + + function _ffi(string memory fn, uint256 x) private returns (uint256) { + string[] memory args = new string[](3); + args[0] = _BIN; + args[1] = fn; + args[2] = vm.toString(bytes32(x)); + bytes memory result = vm.ffi(args); + return abi.decode(result, (uint256)); + } + + function _cbrtFloor(uint256 x) internal override returns (uint256) { + return _ffi("cbrt_floor", x); + } + + function _cbrtUp(uint256 x) internal override returns (uint256) { + return _ffi("cbrt_up", x); + } +} diff --git a/test/0.8.25/formal-model/Sqrt512Model.t.sol b/test/0.8.25/formal-model/Sqrt512Model.t.sol new file mode 100644 index 000000000..0abb76487 --- /dev/null +++ b/test/0.8.25/formal-model/Sqrt512Model.t.sol @@ -0,0 +1,133 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.25; + +import {SlowMath} from "../SlowMath.sol"; +import {Test} from "@forge-std/Test.sol"; + +/// @dev Fuzz-tests the generated Lean models of 512Math.sqrt and +/// 512Math.osqrtUp against the same correctness properties used in +/// 512Math.t.sol. Calls the compiled Lean evaluator via `vm.ffi`. +/// +/// Requires the `sqrt512-model` binary to be pre-built: +/// cd formal/sqrt/Sqrt512Proof && lake build sqrt512-model +contract Sqrt512ModelTest is Test { + string private constant _BIN = "formal/sqrt/Sqrt512Proof/.lake/build/bin/sqrt512-model"; + + // -- helpers ---------------------------------------------------------- + + function _ffi1(string memory fn, uint256 x_hi, uint256 x_lo) internal returns (uint256) { + string[] memory args = new string[](4); + args[0] = _BIN; + args[1] = fn; + args[2] = vm.toString(bytes32(x_hi)); + args[3] = vm.toString(bytes32(x_lo)); + bytes memory result = vm.ffi(args); + return abi.decode(result, (uint256)); + } + + function _ffi2(string memory fn, uint256 x_hi, uint256 x_lo) internal returns (uint256, uint256) { + string[] memory args = new string[](4); + args[0] = _BIN; + args[1] = fn; + args[2] = vm.toString(bytes32(x_hi)); + args[3] = vm.toString(bytes32(x_lo)); + bytes memory result = vm.ffi(args); + return abi.decode(result, (uint256, uint256)); + } + + /// @dev 512-bit comparison: (a_hi, a_lo) > (b_hi, b_lo) + function _gt512(uint256 aH, uint256 aL, uint256 bH, uint256 bL) internal pure returns (bool) { + return aH > bH || (aH == bH && aL > bL); + } + + /// @dev 512-bit comparison: (a_hi, a_lo) >= (b_hi, b_lo) + function _ge512(uint256 aH, uint256 aL, uint256 bH, uint256 bL) internal pure returns (bool) { + return aH > bH || (aH == bH && aL >= bL); + } + + // -- floor sqrt: model_sqrt512_evm (x_hi > 0) ------------------------ + + function testSqrt512Model(uint256 x_hi, uint256 x_lo) external { + // _sqrt assumes x_hi != 0 (the public sqrt dispatches to 256-bit sqrt otherwise) + vm.assume(x_hi != 0); + + uint256 r = _ffi1("sqrt512", x_hi, x_lo); + + // r^2 <= x + (uint256 r2_lo, uint256 r2_hi) = SlowMath.fullMul(r, r); + assertTrue(!_gt512(r2_hi, r2_lo, x_hi, x_lo), "sqrt too high"); + + // (r+1)^2 > x (unless r == max uint256) + if (r == type(uint256).max) { + assertTrue( + x_hi > 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe + || (x_hi == 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe && x_lo != 0), + "sqrt too low (overflow)" + ); + } else { + uint256 r1 = r + 1; + (r2_lo, r2_hi) = SlowMath.fullMul(r1, r1); + assertTrue(_gt512(r2_hi, r2_lo, x_hi, x_lo), "sqrt too low"); + } + } + + // -- floor sqrt: model_sqrt512_wrapper_evm (full range) --------------- + + function testSqrt512WrapperModel(uint256 x_hi, uint256 x_lo) external { + uint256 r = _ffi1("sqrt512_wrapper", x_hi, x_lo); + + // r^2 <= x + (uint256 r2_lo, uint256 r2_hi) = SlowMath.fullMul(r, r); + assertTrue(!_gt512(r2_hi, r2_lo, x_hi, x_lo), "wrapper sqrt too high"); + + // (r+1)^2 > x + if (r == type(uint256).max) { + assertTrue( + x_hi > 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe + || (x_hi == 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe && x_lo != 0), + "wrapper sqrt too low (overflow)" + ); + } else { + uint256 r1 = r + 1; + (r2_lo, r2_hi) = SlowMath.fullMul(r1, r1); + assertTrue(_gt512(r2_hi, r2_lo, x_hi, x_lo), "wrapper sqrt too low"); + } + } + + // -- ceiling sqrt: model_osqrtUp_evm ---------------------------------- + + function testOsqrtUpModel(uint256 x_hi, uint256 x_lo) external { + (uint256 r_hi, uint256 r_lo) = _ffi2("osqrtUp", x_hi, x_lo); + + // Compute r^2 = (r_hi * 2^256 + r_lo)^2 + // For the ceiling sqrt, r_hi is 0 or 1 (result fits in 257 bits max). + // When r_hi = 0: r^2 = r_lo * r_lo (fits in 512 bits). + // When r_hi = 1: r_lo = 0, r = 2^256, r^2 = 2^512 which overflows. + // This only happens when x = 2^512 - 1 (all ones), but x < 2^512. + + if (r_hi == 0) { + // x <= r_lo^2 + (uint256 r2_lo, uint256 r2_hi) = SlowMath.fullMul(r_lo, r_lo); + assertTrue(_ge512(r2_hi, r2_lo, x_hi, x_lo), "osqrtUp too low"); + + // (r_lo - 1)^2 < x (r_lo is minimal) + if (r_lo > 0) { + (uint256 rm2_lo, uint256 rm2_hi) = SlowMath.fullMul(r_lo - 1, r_lo - 1); + assertTrue(!_ge512(rm2_hi, rm2_lo, x_hi, x_lo), "osqrtUp too high"); + } else { + // r = 0, x must be 0 + assertEq(x_hi, 0, "osqrtUp r=0 but x_hi!=0"); + assertEq(x_lo, 0, "osqrtUp r=0 but x_lo!=0"); + } + } else { + // r_hi = 1, r = 2^256. x <= (2^256)^2 = 2^512 which always holds. + // But x must be > (2^256 - 1)^2 = 2^512 - 2^257 + 1. + assertEq(r_hi, 1, "osqrtUp r_hi > 1"); + assertEq(r_lo, 0, "osqrtUp r_hi=1 but r_lo!=0"); + // (2^256 - 1)^2 < x + uint256 rM = type(uint256).max; + (uint256 rm2_lo, uint256 rm2_hi) = SlowMath.fullMul(rM, rM); + assertTrue(_gt512(x_hi, x_lo, rm2_hi, rm2_lo), "osqrtUp too high (r=2^256)"); + } + } +} diff --git a/test/0.8.25/formal-model/SqrtModel.t.sol b/test/0.8.25/formal-model/SqrtModel.t.sol new file mode 100644 index 000000000..bba9917a2 --- /dev/null +++ b/test/0.8.25/formal-model/SqrtModel.t.sol @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.25; + +import {SqrtTest} from "../Sqrt.t.sol"; + +/// @dev Runs the SqrtTest fuzz suite against the generated Lean model +/// via `vm.ffi`. Requires the `sqrt-model` binary to be pre-built: +/// cd formal/sqrt/SqrtProof && lake build sqrt-model +contract SqrtModelTest is SqrtTest { + string private constant _BIN = "formal/sqrt/SqrtProof/.lake/build/bin/sqrt-model"; + + function _ffi(string memory fn, uint256 x) private returns (uint256) { + string[] memory args = new string[](3); + args[0] = _BIN; + args[1] = fn; + args[2] = vm.toString(bytes32(x)); + bytes memory result = vm.ffi(args); + return abi.decode(result, (uint256)); + } + + function _sqrtFloor(uint256 x) internal override returns (uint256) { + return _ffi("sqrt_floor", x); + } + + function _sqrtUp(uint256 x) internal override returns (uint256) { + return _ffi("sqrt_up", x); + } +}