From 9acdbbd1d83a1f886f1f4e6157781b8741e5cea0 Mon Sep 17 00:00:00 2001 From: Kim Morrison Date: Thu, 5 Mar 2026 08:34:28 +0000 Subject: [PATCH 1/3] feat(Query): query complexity framework with sorting lower bound MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a framework for proving upper and lower bounds on query complexity of comparison-based algorithms, using `Prog` (free monad over query types) with oracle-parametric evaluation and structural query counting. Results: - Insertion sort: correctness + O(n²) upper bound - Merge sort: correctness + n·⌈log₂ n⌉ upper bound - Lower bound: any correct comparison sort on an infinite type needs ≥ ⌈log₂(n!)⌉ queries (via adversarial pigeonhole on QueryTree depth) Co-Authored-By: Claude Opus 4.6 --- Cslib.lean | 11 + Cslib/Algorithms/Lean/Query/Bounds.lean | 37 +++ Cslib/Algorithms/Lean/Query/Prog.lean | 92 ++++++ Cslib/Algorithms/Lean/Query/QueryTree.lean | 139 +++++++++ .../Lean/Query/Sort/Insertion/Defs.lean | 41 +++ .../Lean/Query/Sort/Insertion/Lemmas.lean | 168 +++++++++++ Cslib/Algorithms/Lean/Query/Sort/IsSort.lean | 37 +++ Cslib/Algorithms/Lean/Query/Sort/LEQuery.lean | 34 +++ .../Lean/Query/Sort/LowerBound.lean | 214 ++++++++++++++ .../Lean/Query/Sort/Merge/Defs.lean | 94 ++++++ .../Lean/Query/Sort/Merge/Lemmas.lean | 271 ++++++++++++++++++ .../Algorithms/Lean/Query/Sort/QueryTree.lean | 107 +++++++ 12 files changed, 1245 insertions(+) create mode 100644 Cslib/Algorithms/Lean/Query/Bounds.lean create mode 100644 Cslib/Algorithms/Lean/Query/Prog.lean create mode 100644 Cslib/Algorithms/Lean/Query/QueryTree.lean create mode 100644 Cslib/Algorithms/Lean/Query/Sort/Insertion/Defs.lean create mode 100644 Cslib/Algorithms/Lean/Query/Sort/Insertion/Lemmas.lean create mode 100644 Cslib/Algorithms/Lean/Query/Sort/IsSort.lean create mode 100644 Cslib/Algorithms/Lean/Query/Sort/LEQuery.lean create mode 100644 Cslib/Algorithms/Lean/Query/Sort/LowerBound.lean create mode 100644 
Cslib/Algorithms/Lean/Query/Sort/Merge/Defs.lean create mode 100644 Cslib/Algorithms/Lean/Query/Sort/Merge/Lemmas.lean create mode 100644 Cslib/Algorithms/Lean/Query/Sort/QueryTree.lean diff --git a/Cslib.lean b/Cslib.lean index a9d5ffc3..dd44a787 100644 --- a/Cslib.lean +++ b/Cslib.lean @@ -1,6 +1,17 @@ module -- shake: keep-all public import Cslib.Algorithms.Lean.MergeSort.MergeSort +public import Cslib.Algorithms.Lean.Query.Bounds +public import Cslib.Algorithms.Lean.Query.Prog +public import Cslib.Algorithms.Lean.Query.QueryTree +public import Cslib.Algorithms.Lean.Query.Sort.Insertion.Defs +public import Cslib.Algorithms.Lean.Query.Sort.Insertion.Lemmas +public import Cslib.Algorithms.Lean.Query.Sort.IsSort +public import Cslib.Algorithms.Lean.Query.Sort.LEQuery +public import Cslib.Algorithms.Lean.Query.Sort.LowerBound +public import Cslib.Algorithms.Lean.Query.Sort.Merge.Defs +public import Cslib.Algorithms.Lean.Query.Sort.Merge.Lemmas +public import Cslib.Algorithms.Lean.Query.Sort.QueryTree public import Cslib.Algorithms.Lean.TimeM public import Cslib.Computability.Automata.Acceptors.Acceptor public import Cslib.Computability.Automata.Acceptors.OmegaAcceptor diff --git a/Cslib/Algorithms/Lean/Query/Bounds.lean b/Cslib/Algorithms/Lean/Query/Bounds.lean new file mode 100644 index 00000000..a9bfe6c8 --- /dev/null +++ b/Cslib/Algorithms/Lean/Query/Bounds.lean @@ -0,0 +1,37 @@ +/- +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Kim Morrison, Sebastian Graf +-/ +module + +public import Cslib.Algorithms.Lean.Query.Prog + +/-! # Upper and Lower Bounds for Query Complexity + +Definitions of upper and lower bounds on the number of queries a program makes, +quantified over oracles. +-/ + +open Cslib.Query + +public section + +namespace Cslib.Query + +/-- Upper bound: for all oracles, inputs of size ≤ n make at most `bound n` queries. 
-/ +@[expose] def UpperBound (prog : α → Prog Q β) + (size : α → Nat) (bound : Nat → Nat) : Prop := + ∀ (oracle : {ι : Type} → Q ι → ι) (n : Nat) (x : α), + size x ≤ n → (prog x).queriesOn oracle ≤ bound n + +/-- Lower bound: for every size n, there exists an input of size ≤ n and an oracle + making the program perform ≥ `bound n` queries. -/ +@[expose] def LowerBound (prog : α → Prog Q β) + (size : α → Nat) (bound : Nat → Nat) : Prop := + ∀ (n : Nat), ∃ (x : α), size x ≤ n ∧ + ∃ (oracle : {ι : Type} → Q ι → ι), bound n ≤ (prog x).queriesOn oracle + +end Cslib.Query + +end -- public section diff --git a/Cslib/Algorithms/Lean/Query/Prog.lean b/Cslib/Algorithms/Lean/Query/Prog.lean new file mode 100644 index 00000000..638096d0 --- /dev/null +++ b/Cslib/Algorithms/Lean/Query/Prog.lean @@ -0,0 +1,92 @@ +/- +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Kim Morrison, Sebastian Graf, Shreyas Srinivas +-/ +module + +public import Cslib.Foundations.Control.Monad.Free + +/-! # Prog: Programs as Free Monads over Query Types + +`Prog Q α` is an alias for `FreeM Q α`, representing a program that makes queries of type `Q` +and returns a result of type `α`. A query type `Q : Type → Type` maps each query to its +response type. + +The key operations are: +- `Prog.eval oracle p`: evaluate `p` by answering each query using `oracle` +- `Prog.queriesOn oracle p`: count the queries along the oracle-determined path + +Because the oracle is supplied *after* the program produces its query plan (the `Prog` tree), +a sound implementation of `prog` has no way to "guess" what the oracle would respond. +This is the foundation of the anti-cheating guarantee for both upper and lower bounds. + +This provides an alternative to the `TimeM`-based cost analysis in +`Cslib.Algorithms.Lean.MergeSort`: here query counting is structural (derived from the +`Prog` tree) rather than annotation-based.
+-/ + +open Cslib + +public section + +namespace Cslib.Query + +/-- A program that makes queries of type `Q` and returns a result of type `α`. + This is `FreeM Q α`, the free monad over the query type. -/ +abbrev Prog (Q : Type → Type) (α : Type) := FreeM Q α + +namespace Prog + +variable {Q : Type → Type} {α β : Type} + +/-- Evaluate a program by answering each query using `oracle`. -/ +@[expose] def eval (oracle : {ι : Type} → Q ι → ι) : Prog Q α → α + | .pure a => a + | .liftBind op cont => eval oracle (cont (oracle op)) + +/-- Count the number of queries along the path determined by `oracle`. -/ +@[expose] def queriesOn (oracle : {ι : Type} → Q ι → ι) : Prog Q α → Nat + | .pure _ => 0 + | .liftBind op cont => 1 + queriesOn oracle (cont (oracle op)) + +-- Simp lemmas for eval + +@[simp] theorem eval_pure (oracle : {ι : Type} → Q ι → ι) (a : α) : + eval oracle (.pure a : Prog Q α) = a := rfl + +@[simp] theorem eval_liftBind (oracle : {ι : Type} → Q ι → ι) + {ι : Type} (op : Q ι) (cont : ι → Prog Q α) : + eval oracle (.liftBind op cont) = eval oracle (cont (oracle op)) := rfl + +@[simp] theorem eval_bind (oracle : {ι : Type} → Q ι → ι) + (t : Prog Q α) (f : α → Prog Q β) : + eval oracle (t.bind f) = eval oracle (f (eval oracle t)) := by + induction t with + | pure a => rfl + | liftBind op cont ih => exact ih (oracle op) + +-- Simp lemmas for queriesOn + +@[simp] theorem queriesOn_pure (oracle : {ι : Type} → Q ι → ι) (a : α) : + queriesOn oracle (.pure a : Prog Q α) = 0 := rfl + +@[simp] theorem queriesOn_liftBind (oracle : {ι : Type} → Q ι → ι) + {ι : Type} (op : Q ι) (cont : ι → Prog Q α) : + queriesOn oracle (.liftBind op cont) = 1 + queriesOn oracle (cont (oracle op)) := rfl + +@[simp] theorem queriesOn_bind (oracle : {ι : Type} → Q ι → ι) + (t : Prog Q α) (f : α → Prog Q β) : + queriesOn oracle (t.bind f) = + queriesOn oracle t + queriesOn oracle (f (eval oracle t)) := by + induction t with + | pure a => simp [FreeM.bind] + | liftBind op cont ih => + simp only 
[FreeM.bind, queriesOn_liftBind, eval_liftBind, ih (oracle op)] + omega + +end Prog + +end Cslib.Query + +end -- public section diff --git a/Cslib/Algorithms/Lean/Query/QueryTree.lean b/Cslib/Algorithms/Lean/Query/QueryTree.lean new file mode 100644 index 00000000..06475d17 --- /dev/null +++ b/Cslib/Algorithms/Lean/Query/QueryTree.lean @@ -0,0 +1,139 @@ +/- +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Kim Morrison, Sebastian Graf, Shreyas Srinivas +-/ +module + +public import Mathlib.Data.Nat.Log +public import Mathlib.Data.Fintype.Card + +/-! # QueryTree: Decision Trees for Query Complexity Lower Bounds + +`QueryTree Q R α` is a free monad specialized to a single query type: queries take +input `Q` and return `R`, with final results of type `α`. It reifies an algorithm's +query pattern as an explicit decision tree. + +The key advantage over `Prog`/`FreeM` for lower bound proofs is that `R` is a fixed type +parameter (not existentially quantified per query), making structural induction with +pigeonhole arguments straightforward. + +## Main Definitions + +- `QueryTree Q R α` — the decision tree type +- `QueryTree.ask` — the canonical single-query tree +- `QueryTree.eval` — evaluate with a specific oracle +- `QueryTree.queriesOn` — count queries along an oracle-determined path +-/ + +public section + +namespace Cslib.Query + +/-- A decision tree over queries of type `Q → R`, with results of type `α`. + +This is the free monad specialized to a single fixed-type operation, used to reify +algorithms as explicit trees for query complexity lower bounds. -/ +inductive QueryTree (Q : Type) (R : Type) (α : Type) where + /-- A completed computation returning value `a`. -/ + | pure (a : α) : QueryTree Q R α + /-- A query node: asks query `q`, then continues based on the response. 
-/ + | query (q : Q) (cont : R → QueryTree Q R α) : QueryTree Q R α + +namespace QueryTree + +variable {Q R α β γ : Type} + +/-- Lift a single query into the tree. -/ +@[expose] def ask (q : Q) : QueryTree Q R R := .query q .pure + +/-- Monadic bind for query trees. -/ +@[expose] protected def bind : QueryTree Q R α → (α → QueryTree Q R β) → QueryTree Q R β + | .pure a, f => f a + | .query q cont, f => .query q (fun r => (cont r).bind f) + +/-- Functorial map for query trees. -/ +@[expose] protected def map (f : α → β) : QueryTree Q R α → QueryTree Q R β + | .pure a => .pure (f a) + | .query q cont => .query q (fun r => (cont r).map f) + +protected theorem bind_pure : ∀ (x : QueryTree Q R α), x.bind .pure = x + | .pure _ => rfl + | .query _ cont => by simp [QueryTree.bind, QueryTree.bind_pure] + +protected theorem bind_assoc : + ∀ (x : QueryTree Q R α) (f : α → QueryTree Q R β) (g : β → QueryTree Q R γ), + (x.bind f).bind g = x.bind (fun a => (f a).bind g) + | .pure _, _, _ => rfl + | .query _ cont, f, g => by simp [QueryTree.bind, QueryTree.bind_assoc] + +protected theorem bind_pure_comp (f : α → β) : + ∀ (x : QueryTree Q R α), x.bind (.pure ∘ f) = x.map f + | .pure _ => rfl + | .query _ cont => by simp [QueryTree.bind, QueryTree.map, QueryTree.bind_pure_comp] + +protected theorem id_map : ∀ (x : QueryTree Q R α), x.map id = x + | .pure _ => rfl + | .query _ cont => by simp [QueryTree.map, QueryTree.id_map] + +instance : Monad (QueryTree Q R) where + pure := .pure + bind := .bind + +instance : LawfulMonad (QueryTree Q R) := LawfulMonad.mk' + (bind_pure_comp := fun _ _ => rfl) + (id_map := QueryTree.bind_pure) + (pure_bind := fun _ _ => rfl) + (bind_assoc := QueryTree.bind_assoc) + +-- Core operations + +/-- Evaluate a query tree with a specific oracle, returning the final result. 
-/ +@[expose] def eval (oracle : Q → R) : QueryTree Q R α → α + | .pure a => a + | .query q cont => eval oracle (cont (oracle q)) + +/-- Count the number of queries along the path determined by `oracle`. -/ +@[expose] def queriesOn (oracle : Q → R) : QueryTree Q R α → Nat + | .pure _ => 0 + | .query q cont => 1 + queriesOn oracle (cont (oracle q)) + +-- Simp lemmas + +@[simp] theorem eval_pure' (oracle : Q → R) (a : α) : + (QueryTree.pure a : QueryTree Q R α).eval oracle = a := rfl + +@[simp] theorem eval_query (oracle : Q → R) (q : Q) (cont : R → QueryTree Q R α) : + (QueryTree.query q cont).eval oracle = (cont (oracle q)).eval oracle := rfl + +@[simp] theorem eval_bind (oracle : Q → R) (t : QueryTree Q R α) (f : α → QueryTree Q R β) : + (t.bind f).eval oracle = (f (t.eval oracle)).eval oracle := by + induction t with + | pure a => rfl + | query q cont ih => exact ih (oracle q) + +@[simp] theorem queriesOn_pure' (oracle : Q → R) (a : α) : + (QueryTree.pure a : QueryTree Q R α).queriesOn oracle = 0 := rfl + +@[simp] theorem queriesOn_query (oracle : Q → R) (q : Q) (cont : R → QueryTree Q R α) : + (QueryTree.query q cont).queriesOn oracle = 1 + (cont (oracle q)).queriesOn oracle := rfl + +/-- Queries of `t.bind f` = queries of `t` + queries of the continuation. 
-/ +@[simp] theorem queriesOn_bind (oracle : Q → R) (t : QueryTree Q R α) (f : α → QueryTree Q R β) : + (t.bind f).queriesOn oracle = + t.queriesOn oracle + (f (t.eval oracle)).queriesOn oracle := by + induction t with + | pure a => simp [QueryTree.bind, queriesOn, eval] + | query q cont ih => simp only [QueryTree.bind, queriesOn_query, eval_query, ih (oracle q)]; omega + +@[simp] theorem queriesOn_ask (oracle : Q → R) (q : Q) : + (ask q : QueryTree Q R R).queriesOn oracle = 1 := rfl + +@[simp] theorem eval_ask (oracle : Q → R) (q : Q) : + (ask q : QueryTree Q R R).eval oracle = oracle q := rfl + +end QueryTree + +end Cslib.Query + +end -- public section diff --git a/Cslib/Algorithms/Lean/Query/Sort/Insertion/Defs.lean b/Cslib/Algorithms/Lean/Query/Sort/Insertion/Defs.lean new file mode 100644 index 00000000..e32fd638 --- /dev/null +++ b/Cslib/Algorithms/Lean/Query/Sort/Insertion/Defs.lean @@ -0,0 +1,41 @@ +/- +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Kim Morrison +-/ +module + +public import Cslib.Algorithms.Lean.Query.Sort.LEQuery + +/-! # Insertion Sort as a Query Program + +Insertion sort implemented as a `Prog (LEQuery α)`, making all comparison queries explicit. +-/ + +open Cslib.Query + +public section + +namespace Cslib.Query + +/-- Insert `x` into a sorted list using comparison queries. -/ +@[expose] def orderedInsert (x : α) : List α → Prog (LEQuery α) (List α) + | [] => pure [x] + | y :: ys => do + let le ← LEQuery.ask x y + if le then + pure (x :: y :: ys) + else do + let rest ← orderedInsert x ys + pure (y :: rest) + +/-- Sort a list using insertion sort with comparison queries. 
-/ +@[expose] def insertionSort : List α → Prog (LEQuery α) (List α) + | [] => pure [] + | x :: xs => do + let sorted ← insertionSort xs + orderedInsert x sorted + +end Cslib.Query + +end -- public section diff --git a/Cslib/Algorithms/Lean/Query/Sort/Insertion/Lemmas.lean b/Cslib/Algorithms/Lean/Query/Sort/Insertion/Lemmas.lean new file mode 100644 index 00000000..2dbcbce9 --- /dev/null +++ b/Cslib/Algorithms/Lean/Query/Sort/Insertion/Lemmas.lean @@ -0,0 +1,168 @@ +/- +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Kim Morrison +-/ +module + +public import Cslib.Algorithms.Lean.Query.Bounds +public import Cslib.Algorithms.Lean.Query.Sort.IsSort +public import Cslib.Algorithms.Lean.Query.Sort.Insertion.Defs +import Mathlib.Data.List.Sort +import Mathlib.Tactic.Ring +public import Mathlib.Algebra.Group.Defs + +/-! # Insertion Sort: Correctness and Upper Bound + +Proofs that `insertionSort` is a correct comparison sort and uses at most `n²` queries. +All proofs are by plain equational reasoning on `Prog.eval` and `Prog.queriesOn`. 
+-/ + +open Cslib.Query + +public section + +namespace Cslib.Query + +variable {α : Type} + +-- ## Evaluation simp lemmas for orderedInsert + +@[simp] theorem eval_orderedInsert_nil (oracle : {ι : Type} → LEQuery α ι → ι) (x : α) : + (orderedInsert x ([] : List α)).eval oracle = [x] := by + simp [orderedInsert] + +@[simp] theorem eval_orderedInsert_cons (oracle : {ι : Type} → LEQuery α ι → ι) (x y : α) + (ys : List α) : + (orderedInsert x (y :: ys)).eval oracle = + if oracle (.le x y) then x :: y :: ys + else y :: (orderedInsert x ys).eval oracle := by + simp [orderedInsert, LEQuery.ask] + split <;> simp_all + +-- ## Evaluation simp lemmas for insertionSort + +@[simp] theorem eval_insertionSort_nil (oracle : {ι : Type} → LEQuery α ι → ι) : + (insertionSort (α := α) []).eval oracle = [] := by + simp [insertionSort] + +@[simp] theorem eval_insertionSort_cons (oracle : {ι : Type} → LEQuery α ι → ι) + (x : α) (xs : List α) : + (insertionSort (x :: xs)).eval oracle = + (orderedInsert x ((insertionSort xs).eval oracle)).eval oracle := by + simp [insertionSort] + +-- ## Permutation proofs + +theorem orderedInsert_perm (oracle : {ι : Type} → LEQuery α ι → ι) (x : α) (xs : List α) : + ((orderedInsert x xs).eval oracle).Perm (x :: xs) := by + induction xs with + | nil => simp + | cons y ys ih => + simp only [eval_orderedInsert_cons] + split + · exact List.Perm.refl _ + · exact (List.Perm.cons _ ih).trans (List.Perm.swap _ _ _) + +theorem insertionSort_perm (oracle : {ι : Type} → LEQuery α ι → ι) (xs : List α) : + ((insertionSort xs).eval oracle).Perm xs := by + induction xs with + | nil => simp + | cons x xs ih => + simp only [eval_insertionSort_cons] + exact (orderedInsert_perm oracle x _).trans (List.Perm.cons _ ih) + +-- ## Sortedness proofs + +theorem orderedInsert_sorted + (r : α → α → Prop) [DecidableRel r] [Std.Total r] [IsTrans α r] + (oracle : {ι : Type} → LEQuery α ι → ι) + (horacle : ∀ a b, oracle (.le a b) = decide (r a b)) + (x : α) (xs : List α) (hxs : 
xs.Pairwise r) : + ((orderedInsert x xs).eval oracle).Pairwise r := by + induction xs with + | nil => simp + | cons y ys ih => + simp only [eval_orderedInsert_cons, horacle] + split + next h => + have hle : r x y := by simpa [decide_eq_true_eq] using h + exact List.pairwise_cons.mpr ⟨fun z hz => + match List.mem_cons.mp hz with + | .inl h => h ▸ hle + | .inr h => _root_.trans hle (List.rel_of_pairwise_cons hxs h), hxs⟩ + next h => + have hle : ¬ r x y := by simpa [decide_eq_true_eq] using h + have hyx : r y x := (Std.Total.total y x).resolve_right hle + have ih' := ih hxs.of_cons + have hperm := orderedInsert_perm oracle x ys + exact List.pairwise_cons.mpr ⟨fun z hz => + match List.mem_cons.mp (hperm.mem_iff.mp hz) with + | .inl h => h ▸ hyx + | .inr h => List.rel_of_pairwise_cons hxs h, ih'⟩ + +theorem insertionSort_sorted + (r : α → α → Prop) [DecidableRel r] [Std.Total r] [IsTrans α r] + (oracle : {ι : Type} → LEQuery α ι → ι) + (horacle : ∀ a b, oracle (.le a b) = decide (r a b)) + (xs : List α) : + ((insertionSort xs).eval oracle).Pairwise r := by + induction xs with + | nil => simp + | cons x xs ih => + simp only [eval_insertionSort_cons] + exact orderedInsert_sorted r oracle horacle x _ ih + +-- ## Query count proofs + +theorem orderedInsert_queriesOn_le (oracle : {ι : Type} → LEQuery α ι → ι) + (x : α) (xs : List α) : + (orderedInsert x xs).queriesOn oracle ≤ xs.length := by + induction xs with + | nil => simp [orderedInsert] + | cons y ys ih => + unfold orderedInsert LEQuery.ask + simp + split + · simp_all + · simp_all; omega + +theorem insertionSort_queriesOn_le (oracle : {ι : Type} → LEQuery α ι → ι) + (xs : List α) : + (insertionSort xs).queriesOn oracle ≤ xs.length ^ 2 := by + induction xs with + | nil => simp [insertionSort] + | cons x xs ih => + have hq : (insertionSort (x :: xs)).queriesOn oracle = + (insertionSort xs).queriesOn oracle + + (orderedInsert x ((insertionSort xs).eval oracle)).queriesOn oracle := by + simp [insertionSort] + rw [hq] + 
have hlen : ((insertionSort xs).eval oracle).length = xs.length := + (insertionSort_perm oracle xs).length_eq + have hord := orderedInsert_queriesOn_le oracle x ((insertionSort xs).eval oracle) + rw [hlen] at hord + have h1 := Nat.add_le_add ih hord + have hpow : xs.length ^ 2 + xs.length ≤ (xs.length + 1) ^ 2 := by + have : (xs.length + 1) ^ 2 = xs.length ^ 2 + 2 * xs.length + 1 := by ring + omega + simp only [List.length_cons] + exact Nat.le_trans h1 hpow + +-- ## UpperBound and IsSort instances + +public theorem insertionSort_upperBound : + UpperBound (insertionSort (α := α)) List.length (· ^ 2) := by + intro oracle n x hle + exact Nat.le_trans (insertionSort_queriesOn_le oracle x) + (Nat.pow_le_pow_left hle 2) + +public theorem insertionSort_isSort : IsSort (insertionSort (α := α)) where + perm xs oracle := insertionSort_perm oracle xs + sorted := by + intro xs oracle r _ _ _ horacle + exact insertionSort_sorted r oracle horacle xs + +end Cslib.Query + +end -- public section diff --git a/Cslib/Algorithms/Lean/Query/Sort/IsSort.lean b/Cslib/Algorithms/Lean/Query/Sort/IsSort.lean new file mode 100644 index 00000000..b687d534 --- /dev/null +++ b/Cslib/Algorithms/Lean/Query/Sort/IsSort.lean @@ -0,0 +1,37 @@ +/- +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Kim Morrison +-/ +module + +public import Cslib.Algorithms.Lean.Query.Sort.LEQuery + +/-! # IsSort: Specification for Comparison Sorts + +`IsSort sort` asserts that `sort` is a correct comparison sort when viewed as a `Prog` +over `LEQuery α`. Correctness means: for any oracle, the result is a permutation of the +input; and for any oracle implementing a total order, the result is sorted. 
+-/ + +open Cslib.Query + +public section + +namespace Cslib.Query + +/-- A `Prog`-based function is a correct comparison sort if it always produces a permutation + of its input, and produces a sorted list when the oracle implements a total preorder. -/ +structure IsSort (sort : List α → Prog (LEQuery α) (List α)) : Prop where + /-- The sort produces a permutation of its input, for any oracle. -/ + perm : ∀ (xs : List α) (oracle : {ι : Type} → LEQuery α ι → ι), + ((sort xs).eval oracle).Perm xs + /-- The sort produces a sorted list, when the oracle implements a total preorder `r`. -/ + sorted : ∀ (xs : List α) (oracle : {ι : Type} → LEQuery α ι → ι) + (r : α → α → Prop) [DecidableRel r] [Std.Total r] [IsTrans α r] + (_ : ∀ a b, oracle (.le a b) = decide (r a b)), + ((sort xs).eval oracle).Pairwise r + +end Cslib.Query + +end -- public section diff --git a/Cslib/Algorithms/Lean/Query/Sort/LEQuery.lean b/Cslib/Algorithms/Lean/Query/Sort/LEQuery.lean new file mode 100644 index 00000000..b92766c0 --- /dev/null +++ b/Cslib/Algorithms/Lean/Query/Sort/LEQuery.lean @@ -0,0 +1,34 @@ +/- +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Kim Morrison, Sebastian Graf +-/ +module + +public import Cslib.Algorithms.Lean.Query.Prog + +/-! # LEQuery: Comparison Queries for Sorting + +`LEQuery α` is the query type for comparison-based sorting algorithms. +A query `LEQuery.le a b` asks whether `a ≤ b` and returns a `Bool`. +-/ + +open Cslib.Query + +public section + +namespace Cslib.Query + +/-- Comparison query: asks whether `a ≤ b`, returning a `Bool`.
-/ +inductive LEQuery (α : Type) : Type → Type where + | le (a b : α) : LEQuery α Bool + +@[expose] def LEQuery.ask (a b : α) : Prog (LEQuery α) Bool := + .liftBind (.le a b) .pure + +@[simp] theorem LEQuery.eval_ask (oracle : {ι : Type} → LEQuery α ι → ι) (a b : α) : + Prog.eval oracle (LEQuery.ask a b) = oracle (.le a b) := rfl + +end Cslib.Query + +end -- public section diff --git a/Cslib/Algorithms/Lean/Query/Sort/LowerBound.lean b/Cslib/Algorithms/Lean/Query/Sort/LowerBound.lean new file mode 100644 index 00000000..0c6cee6e --- /dev/null +++ b/Cslib/Algorithms/Lean/Query/Sort/LowerBound.lean @@ -0,0 +1,214 @@ +/- +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Kim Morrison, Eric Wieser +-/ +module + +public import Cslib.Algorithms.Lean.Query.Bounds +public import Cslib.Algorithms.Lean.Query.Sort.IsSort +public import Cslib.Algorithms.Lean.Query.Sort.QueryTree +public import Mathlib.Data.List.Sort +public import Mathlib.Data.Nat.Factorial.Basic +public import Mathlib.Data.Fintype.Perm +public import Mathlib.Data.List.FinRange +public import Mathlib.SetTheory.Cardinal.Order + +/-! # Comparison Sorting Lower Bound + +`IsSort.lowerBound_infinite`: any correct comparison sort on an infinite type +has query complexity at least `⌈log₂(n!)⌉` for every input size `n`. + +The proof constructs `n!` distinct total orders on `α` (one per permutation of `n` +embedded elements), shows they produce distinct sorted outputs, and applies +`QueryTree.exists_queriesOn_ge_clog`. + +## Prog-to-QueryTree Bridge + +Since `Prog (LEQuery α) β` uses an existentially quantified response type per query (via +`FreeM.liftBind`), while `QueryTree` has a fixed response type `R`, we provide a conversion +`Prog.toQueryTree` that exploits the fact that `LEQuery α` only has one constructor returning +`Bool`. This lets us apply the combinatorial depth lemma on `QueryTree` and transfer results +back to `Prog`. 
+-/ + +open Cslib.Query + +public section + +namespace Cslib.Query + +-- ## Prog-to-QueryTree bridge for LEQuery + +/-- Convert a `Prog`-oracle to a `QueryTree`-oracle for `LEQuery`. -/ +@[expose] def toQTOracle (oracle : {ι : Type} → LEQuery α ι → ι) : (α × α) → Bool := + fun (a, b) => oracle (.le a b) + +/-- Convert a `QueryTree`-oracle to a `Prog`-oracle for `LEQuery`. -/ +@[expose] def fromQTOracle (f : (α × α) → Bool) : {ι : Type} → LEQuery α ι → ι + | _, .le a b => f (a, b) + +@[simp] theorem fromQTOracle_le (f : (α × α) → Bool) (a b : α) : + fromQTOracle f (.le a b) = f (a, b) := rfl + +@[simp] theorem toQTOracle_fromQTOracle (f : (α × α) → Bool) : + toQTOracle (fromQTOracle f) = f := rfl + +/-- Convert a `Prog (LEQuery α)` program to a `QueryTree (α × α) Bool` decision tree. -/ +@[expose] def Prog.toQueryTree : Prog (LEQuery α) β → QueryTree (α × α) Bool β + | .pure a => .pure a + | .liftBind (.le a b) cont => .query (a, b) (fun r => Prog.toQueryTree (cont r)) + +/-- Evaluation is preserved by the Prog-to-QueryTree conversion. -/ +@[simp] theorem Prog.toQueryTree_eval (oracle : {ι : Type} → LEQuery α ι → ι) : + (p : Prog (LEQuery α) β) → + p.toQueryTree.eval (toQTOracle oracle) = p.eval oracle + | .pure _ => rfl + | .liftBind (.le a b) cont => by + simp only [toQueryTree, QueryTree.eval_query, Prog.eval, toQTOracle] + exact toQueryTree_eval oracle (cont (oracle (.le a b))) + +/-- Query count is preserved by the Prog-to-QueryTree conversion. -/ +@[simp] theorem Prog.toQueryTree_queriesOn (oracle : {ι : Type} → LEQuery α ι → ι) : + (p : Prog (LEQuery α) β) → + p.toQueryTree.queriesOn (toQTOracle oracle) = p.queriesOn oracle + | .pure _ => rfl + | .liftBind (.le a b) cont => by + simp only [toQueryTree, QueryTree.queriesOn_query, Prog.queriesOn, toQTOracle] + exact congrArg (1 + ·) (toQueryTree_queriesOn oracle (cont (oracle (.le a b)))) + +-- ## infinitePermOrder: constructing n! 
distinct total orders + +open Classical in +/-- A total order on an infinite type `α` that orders `n` embedded elements + (via `Infinite.natEmbedding`) according to `σ⁻¹`, with embedded elements + preceding all others, and a well-ordering among non-embedded elements. -/ +private noncomputable def infinitePermOrder [Infinite α] (n : Nat) + (σ : Equiv.Perm (Fin n)) (a b : α) : Prop := + if ha : ∃ i : Fin n, (Infinite.natEmbedding α) i.val = a then + if hb : ∃ j : Fin n, (Infinite.natEmbedding α) j.val = b then + σ.symm ha.choose ≤ σ.symm hb.choose + else True + else + if _ : ∃ j : Fin n, (Infinite.natEmbedding α) j.val = b then False + else @LE.le α (IsWellOrder.linearOrder (α := α) WellOrderingRel).toLE a b + +private noncomputable instance [Infinite α] : + DecidableRel (infinitePermOrder (α := α) n σ) := Classical.decRel _ + +private theorem infinitePermOrder.choose_eq [Infinite α] {i : Fin n} + (h : ∃ j : Fin n, (Infinite.natEmbedding α) j.val = (Infinite.natEmbedding α) i.val) : + h.choose = i := by + grind + +private instance [Infinite α] : + IsTrans α (infinitePermOrder (α := α) n σ) where + trans a b c hab hbc := by + letI : LinearOrder α := IsWellOrder.linearOrder WellOrderingRel + unfold infinitePermOrder at * + by_cases ha : ∃ i : Fin n, (Infinite.natEmbedding α) i.val = a <;> + by_cases hb : ∃ j : Fin n, (Infinite.natEmbedding α) j.val = b <;> + by_cases hc : ∃ k : Fin n, (Infinite.natEmbedding α) k.val = c <;> + grind + +private instance [Infinite α] : + Std.Total (infinitePermOrder (α := α) n σ) where + total a b := by + letI : LinearOrder α := IsWellOrder.linearOrder WellOrderingRel + unfold infinitePermOrder + by_cases ha : ∃ i : Fin n, (Infinite.natEmbedding α) i.val = a + · simp only [dite_else_true] + grind + · simp_all only [reduceDIte, dite_eq_ite, if_true_left] + grind + +attribute [local grind inj] Equiv.injective in +private instance [Infinite α] : + Std.Antisymm (infinitePermOrder (α := α) n σ) where + antisymm a b hab hba := by + letI : 
LinearOrder α := IsWellOrder.linearOrder WellOrderingRel + simp only [infinitePermOrder] at hab hba + by_cases ha : ∃ i : Fin n, (Infinite.natEmbedding α) i.val = a <;> + by_cases hb : ∃ j : Fin n, (Infinite.natEmbedding α) j.val = b <;> + simp_all only [↓reduceDIte, not_exists] <;> grind + +/-- `infinitePermOrder` restricted to embedded values matches `σ⁻¹(·) ≤ σ⁻¹(·)`. -/ +@[grind =] +private theorem infinitePermOrder_on_embedded [Infinite α] {i j : Fin n} : + infinitePermOrder (α := α) n σ ((Infinite.natEmbedding α) i.val) + ((Infinite.natEmbedding α) j.val) ↔ σ.symm i ≤ σ.symm j := by + have hi : ∃ k : Fin n, (Infinite.natEmbedding α) k.val = (Infinite.natEmbedding α) i.val := + ⟨i, rfl⟩ + have hj : ∃ k : Fin n, (Infinite.natEmbedding α) k.val = (Infinite.natEmbedding α) j.val := + ⟨j, rfl⟩ + grind [infinitePermOrder] + +/-- `map (ι ∘ Fin.val ∘ σ) (finRange n)` is pairwise sorted by `infinitePermOrder n σ`. -/ +private theorem pairwise_map_infinitePermOrder [Infinite α] (σ : Equiv.Perm (Fin n)) : + List.Pairwise (infinitePermOrder (α := α) n σ) + ((List.finRange n).map (fun i => (Infinite.natEmbedding α) (σ i).val)) := by + rw [List.pairwise_map] + exact (List.pairwise_le_finRange n).imp fun hab => by grind + +/-- `map (ι ∘ Fin.val ∘ σ) (finRange n)` is a permutation of `map (ι ∘ Fin.val) (finRange n)`. -/ +private theorem map_perm_of_infinite_embedding [Infinite α] (σ : Equiv.Perm (Fin n)) : + ((List.finRange n).map (fun i => (Infinite.natEmbedding α) (σ i).val)).Perm + ((List.finRange n).map (fun i => (Infinite.natEmbedding α) i.val)) := by + rw [show (fun i => (Infinite.natEmbedding α) (σ i).val) = + (fun i => (Infinite.natEmbedding α) i.val) ∘ σ from rfl] + grind [Equiv.Perm.map_finRange_perm] + +/-- Different permutations give different `map (ι ∘ Fin.val ∘ σ) (finRange n)`. 
-/ +private theorem map_infinite_embedding_injective [Infinite α] : + Function.Injective (fun σ : Equiv.Perm (Fin n) => + (List.finRange n).map (fun i => (Infinite.natEmbedding α) (σ i).val)) := by + intro σ τ h + exact Equiv.ext fun i => by + have := List.map_inj_left.mp h i (List.mem_finRange i) + grind + +-- ## Main theorem + +/-- Any correct comparison sort on an infinite type has query complexity at least `⌈log₂(n!)⌉` + for every input size `n`. -/ +theorem IsSort.lowerBound_infinite [Infinite α] + {sort : List α → Prog (LEQuery α) (List α)} + (h : IsSort sort) : + LowerBound sort List.length (fun n => Nat.clog 2 (Nat.factorial n)) := by + intro n + set ι := Infinite.natEmbedding α + refine ⟨(List.finRange n).map (fun i => ι i.val), by simp, ?_⟩ + set xs := (List.finRange n).map (fun i => ι i.val) + set tree := (sort xs).toQueryTree + have hcard : Fintype.card (Equiv.Perm (Fin n)) = Nat.factorial n := by + rw [Fintype.card_perm, Fintype.card_fin] + let e := Fintype.equivFinOfCardEq hcard + -- Define Prog-level oracles, then derive QueryTree oracles from them + let progOracles : Fin (Nat.factorial n) → ({ι : Type} → LEQuery α ι → ι) := + fun i => fromQTOracle (fun p => decide (infinitePermOrder n (e.symm i) p.1 p.2)) + let qtOracles : Fin (Nat.factorial n) → ((α × α) → Bool) := + fun i => toQTOracle (progOracles i) + -- Each oracle produces a unique sorted output + have h_inj : Function.Injective (fun i => tree.eval (qtOracles i)) := by + intro i j h_eval + suffices key : ∀ i, (sort xs).eval (progOracles i) = + (List.finRange n).map (fun k => ι ((e.symm i) k).val) by + simp only [tree, qtOracles, Prog.toQueryTree_eval] at h_eval + rw [key, key] at h_eval + exact e.symm.injective (map_infinite_embedding_injective h_eval) + intro i + have h_perm := h.perm xs (progOracles i) + have h_sorted := h.sorted xs (progOracles i) + (infinitePermOrder (α := α) n (e.symm i)) + (fun a b => by simp [progOracles]) + exact h_perm.trans (map_perm_of_infinite_embedding (e.symm 
i)).symm |>.eq_of_pairwise' + h_sorted (pairwise_map_infinitePermOrder (e.symm i)) + -- Apply the depth lemma + obtain ⟨i, hi⟩ := QueryTree.exists_queriesOn_ge_clog tree qtOracles (Nat.factorial_pos n) h_inj + refine ⟨progOracles i, ?_⟩ + simp only [tree, qtOracles, Prog.toQueryTree_queriesOn] at hi + exact hi + +end Cslib.Query + +end -- public section diff --git a/Cslib/Algorithms/Lean/Query/Sort/Merge/Defs.lean b/Cslib/Algorithms/Lean/Query/Sort/Merge/Defs.lean new file mode 100644 index 00000000..15ebf19b --- /dev/null +++ b/Cslib/Algorithms/Lean/Query/Sort/Merge/Defs.lean @@ -0,0 +1,94 @@ +/- +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Sorrachai Yingchareonthawornchai, Kim Morrison +-/ +module + +public import Cslib.Algorithms.Lean.Query.Sort.LEQuery + +/-! # Merge Sort as a Query Program + +Merge sort implemented as a `Prog (LEQuery α)`, making all comparison queries explicit. +Uses an alternating split (odds/evens), so `split` never computes `List.length`; termination of +`mergeSort` is still by `xs.length`, via `split_fst_length_lt` and `split_snd_length_lt`. +-/ + +open Cslib.Query + +public section + +namespace Cslib.Query + +/-- Split a list into two halves by alternating elements. 
-/ +@[expose] def split : List α → List α × List α + | [] => ([], []) + | [x] => ([x], []) + | x :: y :: zs => + let (l, r) := split zs + (x :: l, y :: r) + +@[simp] theorem split_nil : split (α := α) [] = ([], []) := rfl +@[simp] theorem split_singleton (x : α) : split [x] = ([x], []) := rfl +@[simp] theorem split_cons_cons (x y : α) (zs : List α) : + split (x :: y :: zs) = ((split zs).1 |>.cons x, (split zs).2 |>.cons y) := by + simp [split] + +theorem split_fst_length_eq : ∀ (xs : List α), + (split xs).1.length = (xs.length + 1) / 2 + | [] => by simp [split] + | [_] => by simp [split] + | _ :: _ :: zs => by + simp only [split_cons_cons, List.length_cons] + have := split_fst_length_eq zs + omega + +theorem split_snd_length_eq : ∀ (xs : List α), + (split xs).2.length = xs.length / 2 + | [] => by simp [split] + | [_] => by simp [split] + | _ :: _ :: zs => by + simp only [split_cons_cons, List.length_cons] + have := split_snd_length_eq zs + omega + +theorem split_fst_length_lt (x y : α) (zs : List α) : + (split (x :: y :: zs)).1.length < (x :: y :: zs).length := by + simp only [split_fst_length_eq, List.length_cons]; omega + +theorem split_snd_length_lt (x y : α) (zs : List α) : + (split (x :: y :: zs)).2.length < (x :: y :: zs).length := by + simp only [split_snd_length_eq, List.length_cons]; omega + +/-- Merge two sorted lists using comparison queries. -/ +@[expose] def merge (xs ys : List α) : Prog (LEQuery α) (List α) := + match xs, ys with + | [], ys => pure ys + | xs, [] => pure xs + | x :: xs', y :: ys' => do + let le ← LEQuery.ask x y + if le then do + let rest ← merge xs' (y :: ys') + pure (x :: rest) + else do + let rest ← merge (x :: xs') ys' + pure (y :: rest) +termination_by xs.length + ys.length + +/-- Sort a list using merge sort with comparison queries. 
-/ +@[expose] def mergeSort (xs : List α) : Prog (LEQuery α) (List α) := + match xs with + | [] => pure [] + | [x] => pure [x] + | x :: y :: zs => do + let sl ← mergeSort (split (x :: y :: zs)).1 + let sr ← mergeSort (split (x :: y :: zs)).2 + merge sl sr +termination_by xs.length +decreasing_by + · exact split_fst_length_lt x y zs + · exact split_snd_length_lt x y zs + +end Cslib.Query + +end -- public section diff --git a/Cslib/Algorithms/Lean/Query/Sort/Merge/Lemmas.lean b/Cslib/Algorithms/Lean/Query/Sort/Merge/Lemmas.lean new file mode 100644 index 00000000..8ea9bcfe --- /dev/null +++ b/Cslib/Algorithms/Lean/Query/Sort/Merge/Lemmas.lean @@ -0,0 +1,271 @@ +/- +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Sorrachai Yingchareonthawornchai, Kim Morrison +-/ +module + +public import Cslib.Algorithms.Lean.Query.Bounds +public import Cslib.Algorithms.Lean.Query.Sort.IsSort +public import Cslib.Algorithms.Lean.Query.Sort.Merge.Defs +import Mathlib.Data.List.Sort +public import Mathlib.Algebra.Group.Defs +public import Mathlib.Data.Nat.Log + +/-! # Merge Sort: Correctness and Upper Bound + +Proofs that `mergeSort` is a correct comparison sort and uses at most `n * ⌈log₂ n⌉` queries. +All proofs are by plain equational reasoning on `Prog.eval` and `Prog.queriesOn`. 
+-/ + +open Cslib.Query + +public section + +namespace Cslib.Query + +variable {α : Type} + +-- ## Split lemmas + +theorem split_perm : ∀ (xs : List α), + ((split xs).1 ++ (split xs).2).Perm xs + | [] => List.Perm.refl _ + | [_] => List.Perm.refl _ + | x :: y :: zs => by + simp only [split_cons_cons] + show ((x :: (split zs).1) ++ (y :: (split zs).2)).Perm (x :: y :: zs) + rw [List.cons_append] + refine List.Perm.cons _ ?_ + -- goal: ((split zs).1 ++ y :: (split zs).2).Perm (y :: zs) + exact (List.perm_middle).trans (List.Perm.cons _ (split_perm zs)) + +theorem split_lengths_add (xs : List α) : + (split xs).1.length + (split xs).2.length = xs.length := by + simp [split_fst_length_eq, split_snd_length_eq]; omega + +-- ## Evaluation simp lemmas for merge + +@[simp] theorem eval_merge_nil_left (oracle : {ι : Type} → LEQuery α ι → ι) (ys : List α) : + (merge ([] : List α) ys).eval oracle = ys := by + simp [merge] + +@[simp] theorem eval_merge_nil_right (oracle : {ι : Type} → LEQuery α ι → ι) (xs : List α) : + (merge xs ([] : List α)).eval oracle = xs := by + cases xs <;> simp [merge] + +@[simp] theorem eval_merge_cons_cons (oracle : {ι : Type} → LEQuery α ι → ι) + (x : α) (xs' : List α) (y : α) (ys' : List α) : + (merge (x :: xs') (y :: ys')).eval oracle = + if oracle (.le x y) + then x :: (merge xs' (y :: ys')).eval oracle + else y :: (merge (x :: xs') ys').eval oracle := by + simp [merge, LEQuery.ask] + split <;> simp_all + +-- ## Evaluation simp lemmas for mergeSort + +@[simp] theorem eval_mergeSort_nil (oracle : {ι : Type} → LEQuery α ι → ι) : + (mergeSort (α := α) []).eval oracle = [] := by + simp [mergeSort] + +@[simp] theorem eval_mergeSort_singleton (oracle : {ι : Type} → LEQuery α ι → ι) (x : α) : + (mergeSort [x]).eval oracle = [x] := by + simp [mergeSort] + +@[simp] theorem eval_mergeSort_cons_cons (oracle : {ι : Type} → LEQuery α ι → ι) + (x y : α) (zs : List α) : + (mergeSort (x :: y :: zs)).eval oracle = + (merge + ((mergeSort (split (x :: y :: 
zs)).1).eval oracle) + ((mergeSort (split (x :: y :: zs)).2).eval oracle)).eval oracle := by + simp [mergeSort] + +-- ## Permutation proofs + +theorem merge_perm (oracle : {ι : Type} → LEQuery α ι → ι) (xs ys : List α) : + ((merge xs ys).eval oracle).Perm (xs ++ ys) := by + induction xs, ys using merge.induct (α := α) with + | case1 ys => simp + | case2 xs => simp + | case3 x xs' y ys' ih_true ih_false => + simp only [eval_merge_cons_cons] + split + · exact List.Perm.cons _ ih_true + · -- goal: (y :: (merge (x :: xs') ys').eval oracle).Perm (x :: xs' ++ y :: ys') + -- ih: ((merge (x :: xs') ys').eval oracle).Perm ((x :: xs') ++ ys') + exact (List.Perm.cons _ ih_false).trans List.perm_middle.symm + +theorem mergeSort_perm (oracle : {ι : Type} → LEQuery α ι → ι) (xs : List α) : + ((mergeSort xs).eval oracle).Perm xs := by + induction xs using mergeSort.induct (α := α) with + | case1 => simp + | case2 x => simp + | case3 x y zs ih_l ih_r => + simp only [eval_mergeSort_cons_cons] + exact (merge_perm oracle _ _).trans ((ih_l.append ih_r).trans (split_perm _)) + +-- ## Sortedness proofs + +/-- If `l` is a permutation of `xs ++ ys`, and `r a` holds for all elements of `xs` and `ys`, + then `r a` holds for all elements of `l`. 
-/ +private theorem forall_mem_of_perm_append {r : α → Prop} {l xs ys : List α} + (hperm : l.Perm (xs ++ ys)) + (hxs : ∀ z ∈ xs, r z) (hys : ∀ z ∈ ys, r z) : + ∀ z ∈ l, r z := by + intro z hz + rcases List.mem_append.mp (hperm.mem_iff.mp hz) with h | h + · exact hxs z h + · exact hys z h + +theorem merge_sorted + (r : α → α → Prop) [DecidableRel r] [Std.Total r] [IsTrans α r] + (oracle : {ι : Type} → LEQuery α ι → ι) + (horacle : ∀ a b, oracle (.le a b) = decide (r a b)) + (xs ys : List α) (hxs : xs.Pairwise r) (hys : ys.Pairwise r) : + ((merge xs ys).eval oracle).Pairwise r := by + induction xs, ys using merge.induct (α := α) with + | case1 ys => simpa + | case2 xs => simpa + | case3 x xs' y ys' ih_true ih_false => + simp only [eval_merge_cons_cons, horacle] + have hxs' := hxs.of_cons + have hys' := hys.of_cons + split + next h => + have hle : r x y := by simpa [decide_eq_true_eq] using h + refine List.pairwise_cons.mpr ⟨?_, ih_true hxs' hys⟩ + exact forall_mem_of_perm_append (merge_perm oracle xs' (y :: ys')) + (fun _ hz => List.rel_of_pairwise_cons hxs hz) + (fun z hz => by + rcases List.mem_cons.mp hz with rfl | h + · exact hle + · exact _root_.trans hle (List.rel_of_pairwise_cons hys h)) + next h => + have hle : ¬ r x y := by simpa [decide_eq_true_eq] using h + have hyx : r y x := (Std.Total.total y x).resolve_right hle + refine List.pairwise_cons.mpr ⟨?_, ih_false hxs hys'⟩ + exact forall_mem_of_perm_append (merge_perm oracle (x :: xs') ys') + (fun z hz => by + rcases List.mem_cons.mp hz with rfl | h + · exact hyx + · exact _root_.trans hyx (List.rel_of_pairwise_cons hxs h)) + (fun _ hz => List.rel_of_pairwise_cons hys hz) + +theorem mergeSort_sorted + (r : α → α → Prop) [DecidableRel r] [Std.Total r] [IsTrans α r] + (oracle : {ι : Type} → LEQuery α ι → ι) + (horacle : ∀ a b, oracle (.le a b) = decide (r a b)) + (xs : List α) : + ((mergeSort xs).eval oracle).Pairwise r := by + induction xs using mergeSort.induct (α := α) with + | case1 => simp + | case2 x => 
simp + | case3 x y zs ih_l ih_r => + simp only [eval_mergeSort_cons_cons] + exact merge_sorted r oracle horacle _ _ ih_l ih_r + +-- ## Query count simp lemmas + +@[simp] theorem queriesOn_merge_nil_left (oracle : {ι : Type} → LEQuery α ι → ι) (ys : List α) : + (merge ([] : List α) ys).queriesOn oracle = 0 := by + simp [merge] + +@[simp] theorem queriesOn_merge_nil_right (oracle : {ι : Type} → LEQuery α ι → ι) (xs : List α) : + (merge xs ([] : List α)).queriesOn oracle = 0 := by + cases xs <;> simp [merge] + +@[simp] theorem queriesOn_merge_cons_cons (oracle : {ι : Type} → LEQuery α ι → ι) + (x : α) (xs' : List α) (y : α) (ys' : List α) : + (merge (x :: xs') (y :: ys')).queriesOn oracle = + 1 + if oracle (.le x y) + then (merge xs' (y :: ys')).queriesOn oracle + else (merge (x :: xs') ys').queriesOn oracle := by + simp [merge, LEQuery.ask] + split <;> simp_all + +@[simp] theorem queriesOn_mergeSort_nil (oracle : {ι : Type} → LEQuery α ι → ι) : + (mergeSort (α := α) []).queriesOn oracle = 0 := by + simp [mergeSort] + +@[simp] theorem queriesOn_mergeSort_singleton (oracle : {ι : Type} → LEQuery α ι → ι) (x : α) : + (mergeSort [x]).queriesOn oracle = 0 := by + simp [mergeSort] + +@[simp] theorem queriesOn_mergeSort_cons_cons (oracle : {ι : Type} → LEQuery α ι → ι) + (x y : α) (zs : List α) : + (mergeSort (x :: y :: zs)).queriesOn oracle = + (mergeSort (split (x :: y :: zs)).1).queriesOn oracle + + ((mergeSort (split (x :: y :: zs)).2).queriesOn oracle + + (merge ((mergeSort (split (x :: y :: zs)).1).eval oracle) + ((mergeSort (split (x :: y :: zs)).2).eval oracle)).queriesOn oracle) := by + simp [mergeSort] + +-- ## Query count proofs + +theorem merge_queriesOn_le (oracle : {ι : Type} → LEQuery α ι → ι) + (xs ys : List α) : + (merge xs ys).queriesOn oracle ≤ xs.length + ys.length := by + induction xs, ys using merge.induct (α := α) with + | case1 ys => simp + | case2 xs => simp + | case3 x xs' y ys' ih_true ih_false => + simp only [queriesOn_merge_cons_cons, 
List.length_cons] + split <;> simp_all <;> omega + +/-- The key arithmetic inequality for the merge sort recurrence: + `⌈n/2⌉ * clog(⌈n/2⌉) + ⌊n/2⌋ * clog(⌊n/2⌋) + n ≤ n * clog(n)`. -/ +private theorem mergeSort_bound (n : ℕ) (hn : 2 ≤ n) : + ((n + 1) / 2) * Nat.clog 2 ((n + 1) / 2) + + (n / 2 * Nat.clog 2 (n / 2) + ((n + 1) / 2 + n / 2)) ≤ + n * Nat.clog 2 n := by + have hclog := Nat.clog_of_one_lt (by omega : (1 : Nat) < 2) hn + have hceil : Nat.clog 2 ((n + 1) / 2) + 1 ≤ Nat.clog 2 n := le_of_eq hclog.symm + have hfloor : Nat.clog 2 (n / 2) + 1 ≤ Nat.clog 2 n := + (Nat.add_le_add_right (Nat.clog_mono_right 2 (by omega)) 1).trans hceil + have hsum : (n + 1) / 2 + n / 2 = n := by omega + have h1 := Nat.mul_le_mul_left ((n + 1) / 2) hceil + have h2 := Nat.mul_le_mul_left (n / 2) hfloor + rw [Nat.mul_succ] at h1 h2 + calc _ = ((n + 1) / 2 * Nat.clog 2 ((n + 1) / 2) + (n + 1) / 2) + + (n / 2 * Nat.clog 2 (n / 2) + n / 2) := by omega + _ ≤ (n + 1) / 2 * Nat.clog 2 n + n / 2 * Nat.clog 2 n := Nat.add_le_add h1 h2 + _ = ((n + 1) / 2 + n / 2) * Nat.clog 2 n := (Nat.add_mul ..).symm + _ = n * Nat.clog 2 n := by rw [hsum] + +theorem mergeSort_queriesOn_le (oracle : {ι : Type} → LEQuery α ι → ι) + (xs : List α) : + (mergeSort xs).queriesOn oracle ≤ xs.length * Nat.clog 2 xs.length := by + induction xs using mergeSort.induct (α := α) with + | case1 => simp [mergeSort] + | case2 x => simp [mergeSort] + | case3 x y zs ih_l ih_r => + simp only [queriesOn_mergeSort_cons_cons] + have hml := merge_queriesOn_le oracle + ((mergeSort (split (x :: y :: zs)).1).eval oracle) + ((mergeSort (split (x :: y :: zs)).2).eval oracle) + rw [(mergeSort_perm oracle (split (x :: y :: zs)).1).length_eq, + (mergeSort_perm oracle (split (x :: y :: zs)).2).length_eq, + split_fst_length_eq, split_snd_length_eq] at hml + rw [split_fst_length_eq] at ih_l + rw [split_snd_length_eq] at ih_r + exact Nat.le_trans (Nat.add_le_add ih_l (Nat.add_le_add ih_r hml)) + (mergeSort_bound _ (by simp only 
[List.length_cons]; omega)) + +-- ## UpperBound and IsSort instances + +public theorem mergeSort_upperBound : + UpperBound (mergeSort (α := α)) List.length (fun n => n * Nat.clog 2 n) := by + intro oracle n x hle + exact Nat.le_trans (mergeSort_queriesOn_le oracle x) + (Nat.mul_le_mul hle (Nat.clog_mono_right 2 hle)) + +public theorem mergeSort_isSort : IsSort (mergeSort (α := α)) where + perm xs oracle := mergeSort_perm oracle xs + sorted := by + intro xs oracle r _ _ _ horacle + exact mergeSort_sorted r oracle horacle xs + +end Cslib.Query + +end -- public section diff --git a/Cslib/Algorithms/Lean/Query/Sort/QueryTree.lean b/Cslib/Algorithms/Lean/Query/Sort/QueryTree.lean new file mode 100644 index 00000000..996d740d --- /dev/null +++ b/Cslib/Algorithms/Lean/Query/Sort/QueryTree.lean @@ -0,0 +1,107 @@ +/- +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Kim Morrison, Sebastian Graf +-/ +module + +public import Cslib.Algorithms.Lean.Query.QueryTree +public import Mathlib.Data.Set.Function +public import Mathlib.Combinatorics.Pigeonhole + +/-! # Lower-Bound Lemma for Query Trees + +`QueryTree.exists_queriesOn_ge_clog`: if `n` oracles produce `n` distinct evaluation results +from a query tree with `Fintype` responses, then one of those oracles makes at least +`⌈log_{|R|} n⌉` queries. + +The proof uses the adversarial/partition argument: at each query node, the `n` oracles split by +their answer into `|R|` groups; the largest group (size ≥ ⌈n/|R|⌉) still produces distinct results +in the corresponding subtree, and the induction proceeds there. + +The proof works over an arbitrary `Finset ι` of oracle indices (avoiding re-indexing via +`Fintype.equivFin`), then derives the `Fin n` version as a corollary. 
+-/ + +open Cslib.Query + +public section + +namespace Cslib.Query.QueryTree + +/-- Finset-based version: if the oracles indexed by `S` produce `|S|`-many distinct evaluation + results, then some oracle in `S` makes at least `⌈log_{|R|} |S|⌉` queries. -/ +private theorem exists_mem_queriesOn_ge_clog [Fintype R] + {ι : Type} (t : QueryTree Q R α) (S : Finset ι) (hS : S.Nonempty) + (oracles : ι → (Q → R)) + (h_inj : Set.InjOn (fun i => t.eval (oracles i)) ↑S) : + ∃ i ∈ S, t.queriesOn (oracles i) ≥ Nat.clog (Fintype.card R) S.card := by + classical + induction t generalizing ι S with + | pure a => + obtain ⟨i, hi⟩ := hS + exact ⟨i, hi, by simp [queriesOn, Nat.clog_of_right_le_one + (Finset.card_le_one.mpr fun _ ha _ hb => h_inj ha hb rfl)]⟩ + | query q cont ih => + by_cases hle : S.card ≤ 1 + · obtain ⟨i, hi⟩ := hS; exact ⟨i, hi, by simp [Nat.clog_of_right_le_one hle]⟩ + · push_neg at hle + by_cases hR : Fintype.card R ≤ 1 + · obtain ⟨i, hi⟩ := hS; exact ⟨i, hi, by simp [Nat.clog_of_left_le_one hR]⟩ + · push_neg at hR + -- Find b : R such that S.filter (oracles · q = b) has ≥ ⌈|S|/|R|⌉ elements + have ⟨b, _, hb⟩ : ∃ b ∈ Finset.univ (α := R), + (S.card - 1) / Fintype.card R < (S.filter (fun i => oracles i q = b)).card := by + apply Finset.exists_lt_card_fiber_of_mul_lt_card_of_maps_to + (fun a _ => Finset.mem_univ (oracles a q)) + simp only [Finset.card_univ] + calc Fintype.card R * ((S.card - 1) / Fintype.card R) + = (S.card - 1) / Fintype.card R * Fintype.card R := Nat.mul_comm .. 
+ _ ≤ S.card - 1 := Nat.div_mul_le_self _ _ + _ < S.card := by omega + set S' := S.filter (fun i => oracles i q = b) + have hS' : S'.Nonempty := + Finset.card_pos.mp (Nat.lt_of_le_of_lt (Nat.zero_le _) hb) + -- Restricted injectivity: eval through query q cont agrees with cont b on S' + have h_inj' : Set.InjOn (fun i => (cont b).eval (oracles i)) ↑S' := by + intro i hi j hj heq + have him := Finset.mem_coe.mp hi |> Finset.mem_filter.mp + have hjm := Finset.mem_coe.mp hj |> Finset.mem_filter.mp + exact h_inj (Finset.mem_coe.mpr him.1) (Finset.mem_coe.mpr hjm.1) + (by simp [eval, him.2, hjm.2, heq]) + obtain ⟨i, hi, hiq⟩ := ih b S' hS' oracles h_inj' + have him := Finset.mem_filter.mp hi + refine ⟨i, him.1, ?_⟩ + simp only [queriesOn_query, him.2] + calc Nat.clog (Fintype.card R) S.card + ≤ 1 + Nat.clog (Fintype.card R) S'.card := by + rw [Nat.clog_of_two_le (by omega) (by omega)] + have h_ceil : (S.card + Fintype.card R - 1) / Fintype.card R = + (S.card - 1) / Fintype.card R + 1 := by + rw [show S.card + Fintype.card R - 1 = S.card - 1 + Fintype.card R from by omega] + exact Nat.add_div_right (S.card - 1) (by omega) + have := Nat.clog_mono_right (Fintype.card R) + (show (S.card + Fintype.card R - 1) / Fintype.card R ≤ S'.card by omega) + omega + _ ≤ 1 + (cont b).queriesOn (oracles i) := by omega + +/-- If `n` oracles produce `n` distinct evaluation results from a query tree with `Fintype` + responses, then one of those oracles makes at least `⌈log_{|R|} n⌉` queries. + + This is the core combinatorial lemma for query complexity lower bounds. + The proof uses the adversarial/partition argument: at each query node, the `n` oracles + split by their answer to the query; the largest group (size ≥ ⌈n/|R|⌉) still produces + distinct results in the corresponding subtree, and the induction proceeds there. 
-/ +theorem exists_queriesOn_ge_clog [Fintype R] + (t : QueryTree Q R α) (oracles : Fin n → (Q → R)) + (hn : 0 < n) + (h_inj : Function.Injective (fun i => t.eval (oracles i))) : + ∃ i : Fin n, t.queriesOn (oracles i) ≥ Nat.clog (Fintype.card R) n := by + have ⟨i, _, hi⟩ := exists_mem_queriesOn_ge_clog t Finset.univ + (Finset.univ_nonempty_iff.mpr ⟨⟨0, hn⟩⟩) oracles (h_inj.injOn) + rw [Finset.card_univ, Fintype.card_fin] at hi + exact ⟨i, hi⟩ + +end Cslib.Query.QueryTree + +end -- public section From 44ed9f6b128668579d5cf34b9623fb19edf36b49 Mon Sep 17 00:00:00 2001 From: Kim Morrison Date: Thu, 5 Mar 2026 09:01:36 +0000 Subject: [PATCH 2/3] feat(Query): add Prog.cost and complex multiplication example Add `Prog.cost`, a weighted generalization of `Prog.queriesOn` where each query type can have a different cost. Demonstrate this with complex multiplication: naive (4 muls + 2 adds) vs Gauss's trick (3 muls + 5 adds), proving correctness, exact parametric costs, and the crossover condition. 
Co-Authored-By: Claude Opus 4.6 --- Cslib.lean | 2 + Cslib/Algorithms/Lean/Query/Arith/Defs.lean | 80 +++++++++++++++++++ Cslib/Algorithms/Lean/Query/Arith/Lemmas.lean | 68 ++++++++++++++++ Cslib/Algorithms/Lean/Query/Prog.lean | 33 ++++++++ 4 files changed, 183 insertions(+) create mode 100644 Cslib/Algorithms/Lean/Query/Arith/Defs.lean create mode 100644 Cslib/Algorithms/Lean/Query/Arith/Lemmas.lean diff --git a/Cslib.lean b/Cslib.lean index dd44a787..b8f613a9 100644 --- a/Cslib.lean +++ b/Cslib.lean @@ -1,6 +1,8 @@ module -- shake: keep-all public import Cslib.Algorithms.Lean.MergeSort.MergeSort +public import Cslib.Algorithms.Lean.Query.Arith.Defs +public import Cslib.Algorithms.Lean.Query.Arith.Lemmas public import Cslib.Algorithms.Lean.Query.Bounds public import Cslib.Algorithms.Lean.Query.Prog public import Cslib.Algorithms.Lean.Query.QueryTree diff --git a/Cslib/Algorithms/Lean/Query/Arith/Defs.lean b/Cslib/Algorithms/Lean/Query/Arith/Defs.lean new file mode 100644 index 00000000..6d719089 --- /dev/null +++ b/Cslib/Algorithms/Lean/Query/Arith/Defs.lean @@ -0,0 +1,80 @@ +/- +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Kim Morrison +-/ +module + +public import Cslib.Algorithms.Lean.Query.Prog + +/-! # Arithmetic Queries and Complex Multiplication + +Demonstrates the `Prog.cost` framework with non-uniform query costs. +`ArithQuery α` supports addition, subtraction, and multiplication, each with +independently parametrized costs. + +The motivating example is complex number multiplication, where two algorithms +(naive and Gauss's trick) trade multiplications for additions. With parametric +costs `c_add` and `c_mul`, the optimal choice depends on their ratio. +-/ + +open Cslib.Query + +public section + +namespace Cslib.Query + +/-- Arithmetic queries: addition, subtraction, and multiplication. 
-/ +inductive ArithQuery (α : Type) : Type → Type where + | add (a b : α) : ArithQuery α α + | sub (a b : α) : ArithQuery α α + | mul (a b : α) : ArithQuery α α + +namespace ArithQuery + +@[expose] def doAdd (a b : α) : Prog (ArithQuery α) α := .liftBind (.add a b) .pure +@[expose] def doSub (a b : α) : Prog (ArithQuery α) α := .liftBind (.sub a b) .pure +@[expose] def doMul (a b : α) : Prog (ArithQuery α) α := .liftBind (.mul a b) .pure + +/-- An honest oracle interprets arithmetic queries using the actual ring operations. -/ +@[expose] def honest [Add α] [Sub α] [Mul α] {ι : Type} : ArithQuery α ι → ι + | .add a b => a + b + | .sub a b => a - b + | .mul a b => a * b + +/-- Weighted cost model for arithmetic queries. Subtraction costs the same as addition + (both are linear-time on bignums). -/ +@[expose] def weight (c_add c_mul : Nat) {ι : Type} : ArithQuery α ι → Nat + | .add _ _ => c_add + | .sub _ _ => c_add + | .mul _ _ => c_mul + +end ArithQuery + +/-- Naive complex multiplication: `(a + bi)(c + di) = (ac - bd) + (ad + bc)i`. + Uses 4 multiplications, 1 subtraction, 1 addition. -/ +@[expose] def complexMulNaive (a b c d : α) : Prog (ArithQuery α) (α × α) := do + let ac ← ArithQuery.doMul a c + let bd ← ArithQuery.doMul b d + let ad ← ArithQuery.doMul a d + let bc ← ArithQuery.doMul b c + let real ← ArithQuery.doSub ac bd + let imag ← ArithQuery.doAdd ad bc + pure (real, imag) + +/-- Gauss's trick for complex multiplication: computes `(a+b)(c+d)` to save one + multiplication, at the cost of extra additions and subtractions. + Uses 3 multiplications, 2 subtractions, 3 additions (`a+b`, `c+d`, and `ac+bd`). 
-/ +@[expose] def complexMulGauss (a b c d : α) : Prog (ArithQuery α) (α × α) := do + let ac ← ArithQuery.doMul a c + let bd ← ArithQuery.doMul b d + let apb ← ArithQuery.doAdd a b + let cpd ← ArithQuery.doAdd c d + let abcd ← ArithQuery.doMul apb cpd + let real ← ArithQuery.doSub ac bd + let imag ← ArithQuery.doSub abcd (← ArithQuery.doAdd ac bd) + pure (real, imag) + +end Cslib.Query + +end -- public section diff --git a/Cslib/Algorithms/Lean/Query/Arith/Lemmas.lean b/Cslib/Algorithms/Lean/Query/Arith/Lemmas.lean new file mode 100644 index 00000000..69b1fad6 --- /dev/null +++ b/Cslib/Algorithms/Lean/Query/Arith/Lemmas.lean @@ -0,0 +1,68 @@ +/- +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Kim Morrison +-/ +module + +public import Cslib.Algorithms.Lean.Query.Arith.Defs +import Mathlib.Tactic.Ring +public import Mathlib.Algebra.Ring.Defs + +/-! # Complex Multiplication: Correctness and Cost Analysis + +We prove that both `complexMulNaive` and `complexMulGauss` correctly compute +complex multiplication when given an honest oracle, and compute their exact +costs under a parametric weight function. + +The cost theorems hold for *any* oracle (not just honest ones), because both +algorithms are straight-line (no branching on query results). The correctness +theorems require the honest oracle. 
+-/ + +open Cslib.Query + +public section + +namespace Cslib.Query + +variable {α : Type} + +-- ## Correctness + +theorem complexMulNaive_eval_honest [Ring α] (a b c d : α) : + (complexMulNaive a b c d).eval ArithQuery.honest = (a * c - b * d, a * d + b * c) := by + simp [complexMulNaive, ArithQuery.doMul, ArithQuery.doSub, ArithQuery.doAdd, ArithQuery.honest] + +theorem complexMulGauss_eval_honest [CommRing α] (a b c d : α) : + (complexMulGauss a b c d).eval ArithQuery.honest = (a * c - b * d, a * d + b * c) := by + simp [complexMulGauss, ArithQuery.doMul, ArithQuery.doSub, ArithQuery.doAdd, ArithQuery.honest] + ring + +-- ## Exact cost counts + +theorem complexMulNaive_cost (oracle : {ι : Type} → ArithQuery α ι → ι) + (c_add c_mul : Nat) (a b c d : α) : + (complexMulNaive a b c d).cost oracle (ArithQuery.weight c_add c_mul) = + 4 * c_mul + 2 * c_add := by + simp [complexMulNaive, ArithQuery.doMul, ArithQuery.doSub, ArithQuery.doAdd, ArithQuery.weight] + omega + +theorem complexMulGauss_cost (oracle : {ι : Type} → ArithQuery α ι → ι) + (c_add c_mul : Nat) (a b c d : α) : + (complexMulGauss a b c d).cost oracle (ArithQuery.weight c_add c_mul) = + 3 * c_mul + 5 * c_add := by + simp [complexMulGauss, ArithQuery.doMul, ArithQuery.doSub, ArithQuery.doAdd, ArithQuery.weight] + omega + +-- ## Crossover: Gauss is no costlier than naive iff multiplication costs at least 3× addition + +theorem gauss_le_naive (c_add c_mul : Nat) (h : 3 * c_add ≤ c_mul) : + 3 * c_mul + 5 * c_add ≤ 4 * c_mul + 2 * c_add := by omega + +theorem gauss_le_naive_iff (c_add c_mul : Nat) : + 3 * c_mul + 5 * c_add ≤ 4 * c_mul + 2 * c_add ↔ 3 * c_add ≤ c_mul := by omega + +end Cslib.Query + +end -- public section diff --git a/Cslib/Algorithms/Lean/Query/Prog.lean b/Cslib/Algorithms/Lean/Query/Prog.lean index 638096d0..d6b0f1da 100644 --- a/Cslib/Algorithms/Lean/Query/Prog.lean +++ b/Cslib/Algorithms/Lean/Query/Prog.lean @@ -85,6 +85,39 @@ variable {Q : Type → Type} {α β : Type} simp only [FreeM.bind, 
queriesOn_liftBind, eval_liftBind, ih (oracle op)] omega +/-- Weighted query cost: each query has a cost given by `weight`. -/ +@[expose] def cost (oracle : {ι : Type} → Q ι → ι) + (weight : {ι : Type} → Q ι → Nat) : Prog Q α → Nat + | .pure _ => 0 + | .liftBind op cont => weight op + cost oracle weight (cont (oracle op)) + +-- Simp lemmas for cost + +@[simp] theorem cost_pure (oracle : {ι : Type} → Q ι → ι) + (weight : {ι : Type} → Q ι → Nat) (a : α) : + cost oracle weight (.pure a : Prog Q α) = 0 := rfl + +@[simp] theorem cost_liftBind (oracle : {ι : Type} → Q ι → ι) + (weight : {ι : Type} → Q ι → Nat) {ι : Type} (op : Q ι) (cont : ι → Prog Q α) : + cost oracle weight (.liftBind op cont) = + weight op + cost oracle weight (cont (oracle op)) := rfl + +@[simp] theorem cost_bind (oracle : {ι : Type} → Q ι → ι) + (weight : {ι : Type} → Q ι → Nat) (t : Prog Q α) (f : α → Prog Q β) : + cost oracle weight (t.bind f) = + cost oracle weight t + cost oracle weight (f (eval oracle t)) := by + induction t with + | pure a => simp [FreeM.bind] + | liftBind op cont ih => + simp only [FreeM.bind, cost_liftBind, eval_liftBind, ih (oracle op)] + omega + +theorem queriesOn_eq_cost_one (oracle : {ι : Type} → Q ι → ι) (p : Prog Q α) : + queriesOn oracle p = cost oracle (fun _ => 1) p := by + induction p with + | pure a => rfl + | liftBind op cont ih => simp [ih (oracle op)] + end Prog end Cslib.Query From 83d95c4be1a88819b41eca1fe88d9539fd54d934 Mon Sep 17 00:00:00 2001 From: Kim Morrison Date: Thu, 5 Mar 2026 09:06:52 +0000 Subject: [PATCH 3/3] docs(Query/Arith): clarify these are toy examples of parametrized costs Co-Authored-By: Claude Opus 4.6 --- Cslib/Algorithms/Lean/Query/Arith/Defs.lean | 11 +++++------ Cslib/Algorithms/Lean/Query/Arith/Lemmas.lean | 10 +++++----- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/Cslib/Algorithms/Lean/Query/Arith/Defs.lean b/Cslib/Algorithms/Lean/Query/Arith/Defs.lean index 6d719089..34c2179f 100644 --- 
a/Cslib/Algorithms/Lean/Query/Arith/Defs.lean +++ b/Cslib/Algorithms/Lean/Query/Arith/Defs.lean @@ -9,13 +9,12 @@ public import Cslib.Algorithms.Lean.Query.Prog /-! # Arithmetic Queries and Complex Multiplication -Demonstrates the `Prog.cost` framework with non-uniform query costs. -`ArithQuery α` supports addition, subtraction, and multiplication, each with -independently parametrized costs. +A simple example showing how to use `Prog.cost` with variable/parametrized query costs. -The motivating example is complex number multiplication, where two algorithms -(naive and Gauss's trick) trade multiplications for additions. With parametric -costs `c_add` and `c_mul`, the optimal choice depends on their ratio. +`ArithQuery α` supports addition, subtraction, and multiplication, each with +independently parametrized costs. Complex number multiplication provides a toy example +where two algorithms (naive and Gauss's trick) trade multiplications for additions, +and the optimal choice depends on the cost ratio. -/ open Cslib.Query diff --git a/Cslib/Algorithms/Lean/Query/Arith/Lemmas.lean b/Cslib/Algorithms/Lean/Query/Arith/Lemmas.lean index 69b1fad6..75ce0a85 100644 --- a/Cslib/Algorithms/Lean/Query/Arith/Lemmas.lean +++ b/Cslib/Algorithms/Lean/Query/Arith/Lemmas.lean @@ -11,13 +11,13 @@ public import Mathlib.Algebra.Ring.Defs /-! # Complex Multiplication: Correctness and Cost Analysis +A simple example showing how to use `Prog.cost` with variable/parametrized query costs. + We prove that both `complexMulNaive` and `complexMulGauss` correctly compute complex multiplication when given an honest oracle, and compute their exact -costs under a parametric weight function. - -The cost theorems hold for *any* oracle (not just honest ones), because both -algorithms are straight-line (no branching on query results). The correctness -theorems require the honest oracle. +costs under a parametric weight function. 
The cost theorems hold for *any* oracle +(not just honest ones), because both algorithms are straight-line (no branching +on query results). -/ open Cslib.Query