From 9e8df3c5b62d198f5c2923f6f98354e56ddd2d5e Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Fri, 26 Sep 2025 13:26:27 +0100 Subject: [PATCH 1/5] chore: add stream differential equations --- Blase/Blase/MultiWidth/StreamDiffEq.lean | 0 Blase/Blase/MultiWidth/Tactic.lean | 1 + 2 files changed, 1 insertion(+) create mode 100644 Blase/Blase/MultiWidth/StreamDiffEq.lean diff --git a/Blase/Blase/MultiWidth/StreamDiffEq.lean b/Blase/Blase/MultiWidth/StreamDiffEq.lean new file mode 100644 index 0000000000..e69de29bb2 diff --git a/Blase/Blase/MultiWidth/Tactic.lean b/Blase/Blase/MultiWidth/Tactic.lean index dee3463e43..687935bc2b 100644 --- a/Blase/Blase/MultiWidth/Tactic.lean +++ b/Blase/Blase/MultiWidth/Tactic.lean @@ -4,6 +4,7 @@ import Blase.MultiWidth.GoodFSM import Blase.MultiWidth.Preprocessing import Blase.KInduction.KInduction import Blase.AutoStructs.FormulaToAuto +import Blase.StreamDiffEq import Blase.ReflectMap initialize Lean.registerTraceClass `Bits.MultiWidth From 0f6a0d816ae4d27798a28130cc23fb9eac43844d Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Tue, 30 Sep 2025 10:38:15 +0100 Subject: [PATCH 2/5] chore: blase stream diff eqn --- Blase/Blase/MultiWidth/StreamDiffEq.lean | 48 ++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/Blase/Blase/MultiWidth/StreamDiffEq.lean b/Blase/Blase/MultiWidth/StreamDiffEq.lean index e69de29bb2..20869aa0d0 100644 --- a/Blase/Blase/MultiWidth/StreamDiffEq.lean +++ b/Blase/Blase/MultiWidth/StreamDiffEq.lean @@ -0,0 +1,48 @@ +import Blase.Fast.Circuit +import Blase.Fast.FiniteStateMachine + + + +inductive StreamVar (ι : Type) (npast : Nat) where +| input (i : ι) (kpast : Fin npast) +| output (kpast : Fin npast) (h : 0 < kpast.val) +deriving DecidableEq, Repr, Hashable + +instance [DecidableEq ι] : FinEnum (StreamVar ι npast) where + card := sorry + equiv := sorry + +-- out[0] = ... +-- a stream equation, that says that out[n] is a function of the inputs ι, and the previous outputs out[0], ..., out[n-1] +-- stream differential equation with 'ι' inputs, and +structure StreamDiffEq (ι : Type) (npast : Nat) where + -- | Assumed values of the output for the horizon. + outInit : (kpast : Fin npast) → (h : 0 < kpast.val) → Bool + -- | compute output as a function of past outputs and inputs + outCircuit : Circuit (StreamVar ι npast) + + +def StreamDiffEq.toStream [DecidableEq ι] (s : StreamDiffEq ι npast) (input : ι → BitStream) : BitStream := + let rec go (n : Nat) : Bool := + if n < npast then + s.outInit (Fin.mk n (by omega)) (by omega) + else + let circInput : FinEnumVal (StreamVar ι npast) → Bool := + fun v => match v with + | .input i kpast => input.get (n - kpast.val - 1) + | .output kpast h => go (n - kpast.val - 1) + s.outCircuit.eval circInput + ⟨go⟩ + +def StreamDiffEq.toFSM [DecidableEq ι] [Hashable ι] (s : StreamDiffEq ι npast) : FSM ι where + α := StreamVar ι npast + initCarry + | .input i kpast => false + | .output kpast h => s.outInit kpast (by omega) + outputCirc := sorry + nextStateCirc := sorry + + + + + From 26c1e10f7509bf79017863edc01df3271526eeb2 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Tue, 7 Oct 2025 23:30:41 +0100 Subject: [PATCH 3/5] chore: adding more theory --- Blase/Blase/MultiWidth/StreamDiffEq.lean | 188 +++++++++++++++++++---- 1 file changed, 155 insertions(+), 33 deletions(-) diff --git a/Blase/Blase/MultiWidth/StreamDiffEq.lean b/Blase/Blase/MultiWidth/StreamDiffEq.lean index 20869aa0d0..d67aa4db9d 100644 --- a/Blase/Blase/MultiWidth/StreamDiffEq.lean +++ b/Blase/Blase/MultiWidth/StreamDiffEq.lean @@ -1,48 +1,170 @@ import Blase.Fast.Circuit import Blase.Fast.FiniteStateMachine +structure InputVar (ι : Type) (npast : Nat) where + input : ι + past : Fin npast +deriving DecidableEq, Hashable +private theorem Nat.mul_lt_mul_add {N M : Nat} {x y : Nat} + (hx : x < N) (hy : y < M) : x * M + y < N * M := by + have : 0 < N := by omega + have : N * M ≥ M := by + rcases N with rfl | N + · omega + · simp; rw [Nat.add_mul]; omega + have : y ≤ M - 1 := by omega + have : x * M ≤ N * M - M := by + apply le_sub_of_add_le + rw [← add_one_mul] + apply mul_le_mul_right + omega + have : x * M + y ≤ (N * M - M) + (M - 1) := by omega + have : x * M + y ≤ N * M - 1 := by + apply Nat.le_trans this + omega + omega -inductive StreamVar (ι : Type) (npast : Nat) where -| input (i : ι) (kpast : Fin npast) -| output (kpast : Fin npast) (h : 0 < kpast.val) -deriving DecidableEq, Repr, Hashable +@[simp] +theorem FinEnum_card_eq_self : + FinEnum.card (Fin npast) = npast := by + simp only [FinEnum.card_fin] -instance [DecidableEq ι] : FinEnum (StreamVar ι npast) where - card := sorry - equiv := sorry +instance [instFinEnumI : FinEnum ι] : FinEnum (InputVar ι npast) where + card := + let instFinEnumFinNpast : FinEnum (Fin npast) := by infer_instance + instFinEnumI.card * instFinEnumFinNpast.card + equiv := { + toFun := fun { input, past } => + let instFinEnumFinNpast : FinEnum (Fin npast) := by infer_instance + let finInput := instFinEnumI.equiv.toFun input + let finPast := instFinEnumFinNpast.equiv.toFun past + + ⟨finInput.val * instFinEnumFinNpast.card + finPast.val, by + apply Nat.mul_lt_mul_add + · sorry + · simp only [FinEnum.card_fin] + have := finPast.isLt + simpa using this + ⟩ + invFun := fun x => sorry + left_inv := sorry + right_inv := sorry + } --- out[0] = ... --- a stream equation, that says that out[n] is a function of the inputs ι, and the previous outputs out[0], ..., out[n-1] --- stream differential equation with 'ι' inputs, and structure StreamDiffEq (ι : Type) (npast : Nat) where - -- | Assumed values of the output for the horizon. - outInit : (kpast : Fin npast) → (h : 0 < kpast.val) → Bool - -- | compute output as a function of past outputs and inputs - outCircuit : Circuit (StreamVar ι npast) + inInit : InputVar ι npast → Bool + outCircuit : Circuit (InputVar ι npast) + + +/-- Cons a value onto a bitstream. -/ +def BitStream.cons (b : Bool) (bs : BitStream) : BitStream := + fun n => + match n with + | 0 => b + | n + 1 => bs n + +@[simp] +theorem BitStream.cons_zero (b : Bool) (bs : BitStream) : + (BitStream.cons b bs) 0 = b := rfl + +@[simp] +theorem BitStream.cons_succ (b : Bool) (bs : BitStream) (n : Nat) : + (BitStream.cons b bs) (n + 1) = bs n := rfl + + +/-- Append n bits to the left of a bitstream. -/ +def BitStream.appendLeft (b : BitStream) + (n : Nat) + (env : Fin n → Bool) : BitStream := + match n with + | 0 => b + | n + 1 => BitStream.cons (env 0) (b.appendLeft n (fun k => env k.succ)) + +@[simp] +theorem BitStream.appendLeft_zero (b : BitStream) (env : Fin 0 → Bool) : + b.appendLeft 0 env = b := rfl + + +@[simp] +theorem BitStream.appendLeft_succ (b : BitStream) (n : Nat) + (env : Fin (n + 1) → Bool) : + b.appendLeft (n + 1) env = + BitStream.cons (env 0) (b.appendLeft n (fun k => env k.succ)) := rfl + +@[simp] +theorem BitStream.appendLeft_eq_ite (b : BitStream) (n : Nat) + (env : Fin n → Bool) (k : Nat) : + (b.appendLeft n env) k = + if h : k < n then env ⟨k, h⟩ else b (k - n) := by + induction n generalizing k with + | zero => + simp [BitStream.appendLeft] + | succ n ihn => + simp + induction k + case zero => simp + case succ k ihk => + simp [ihn] + +/-- Lift a circuit to a bitstream by pointwise evaluation. -/ +def BitStream.ofCircuitPointwise(circ : Circuit α) (env : α → BitStream) : BitStream := + fun n => circ.eval (fun a => env a n) +@[simp] +theorem eval_ofCircuitPointwise (circ : Circuit α) (env : α → BitStream) (n : Nat) : + (BitStream.ofCircuitPointwise circ env) n = + circ.eval (fun a => env a n) := rfl -def StreamDiffEq.toStream [DecidableEq ι] (s : StreamDiffEq ι npast) (input : ι → BitStream) : BitStream := - let rec go (n : Nat) : Bool := - if n < npast then - s.outInit (Fin.mk n (by omega)) (by omega) - else - let circInput : FinEnumVal (StreamVar ι npast) → Bool := - fun v => match v with - | .input i kpast => input.get (n - kpast.val - 1) - | .output kpast h => go (n - kpast.val - 1) - s.outCircuit.eval circInput - ⟨go⟩ +/-- Drop the first n bits of a bitstream. -/ +def BitStream.drop (bs : BitStream) (n : Nat) : BitStream := + fun k => bs (k + n) -def StreamDiffEq.toFSM [DecidableEq ι] [Hashable ι] (s : StreamDiffEq ι npast) : FSM ι where - α := StreamVar ι npast - initCarry - | .input i kpast => false - | .output kpast h => s.outInit kpast (by omega) - outputCirc := sorry - nextStateCirc := sorry +@[simp] +theorem BitStream.eval_drop (bs : BitStream) (n k : Nat) : + (BitStream.drop bs n) k = bs (k + n) := rfl +/-- +Produce the output stream differential equation as a bitstream. +-/ +def StreamDiffEq.toStream [DecidableEq ι] + (s : StreamDiffEq ι npast) (inputStream : ι → BitStream) : BitStream := + let newStreams : ι → BitStream := fun i => + BitStream.appendLeft (inputStream i) npast + (fun kpast => s.inInit ⟨i, kpast⟩) + BitStream.ofCircuitPointwise s.outCircuit fun iv => + (newStreams iv.input).drop iv.past - +@[simp] +theorem StreamDiffEq.toStream_eq_eval_of_lt [DecidableEq ι] + (s : StreamDiffEq ι npast) + (env : InputVar ι npast → Bool) + (inputStream : ι → BitStream) (n : Nat) (hn : ix < npast) : + (s.toStream inputStream) ix = s.outCircuit.eval env := by + simp [StreamDiffEq.toStream, BitStream.eval_drop] + congr + ext i + sorry +/-- +Produce the output stream differential equation as a FSM +-/ +def StreamDiffEq.toFSM [DecidableEq ι] [Hashable ι] [FinEnum ι] + (s : StreamDiffEq ι npast) : FSM ι where + α := InputVar ι npast + initCarry := s.inInit + outputCirc := s.outCircuit.map Sum.inl + -- | we need to rotate, and send bits to the more past state. + nextStateCirc := fun iv => + if h : iv.past.val = 0 then + Circuit.var true <| .inr iv.input + else + Circuit.var true <| .inl ⟨iv.input, ⟨iv.past.val - 1, by omega⟩⟩ +theorem StreamDiffEq.toFsm_eval_eq_toStream [DecidableEq ι] [Hashable ι] + [FinEnum ι] (s : StreamDiffEq ι npast) (inputStream : ι → BitStream) : + (s.toFSM.eval inputStream) = (s.toStream inputStream) := by + ext i + by_cases hi : i < npast + · simp + · sorry From a624a2bfe4cc2222a42a82c9d7c38b9fe735d204 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Mon, 13 Oct 2025 15:16:52 +0800 Subject: [PATCH 4/5] chore: stream diff eq proof --- Blase/Blase/MultiWidth/StreamDiffEq.lean | 169 ++++++++++++++++++----- 1 file changed, 137 insertions(+), 32 deletions(-) diff --git a/Blase/Blase/MultiWidth/StreamDiffEq.lean b/Blase/Blase/MultiWidth/StreamDiffEq.lean index d67aa4db9d..c820a53ac3 100644 --- a/Blase/Blase/MultiWidth/StreamDiffEq.lean +++ b/Blase/Blase/MultiWidth/StreamDiffEq.lean @@ -2,7 +2,7 @@ import Blase.Fast.Circuit import Blase.Fast.FiniteStateMachine structure InputVar (ι : Type) (npast : Nat) where - input : ι + inputIx : ι past : Fin npast deriving DecidableEq, Hashable @@ -35,9 +35,9 @@ instance [instFinEnumI : FinEnum ι] : FinEnum (InputVar ι npast) where let instFinEnumFinNpast : FinEnum (Fin npast) := by infer_instance instFinEnumI.card * instFinEnumFinNpast.card equiv := { - toFun := fun { input, past } => + toFun := fun { inputIx, past } => let instFinEnumFinNpast : FinEnum (Fin npast) := by infer_instance - let finInput := instFinEnumI.equiv.toFun input + let finInput := instFinEnumI.equiv.toFun inputIx let finPast := instFinEnumFinNpast.equiv.toFun past ⟨finInput.val * instFinEnumFinNpast.card + finPast.val, by @@ -53,7 +53,9 @@ instance [instFinEnumI : FinEnum ι] : FinEnum (InputVar ι npast) where } structure StreamDiffEq (ι : Type) (npast : Nat) where - inInit : InputVar ι npast → Bool + /-- The output value for the first n steps. -/ + initialOutputVal : Fin npast → Bool + /-- The output as a circuit of the past 'npast' inputs. -/ outCircuit : Circuit (InputVar ι npast) @@ -74,9 +76,7 @@ theorem BitStream.cons_succ (b : Bool) (bs : BitStream) (n : Nat) : /-- Append n bits to the left of a bitstream. -/ -def BitStream.appendLeft (b : BitStream) - (n : Nat) - (env : Fin n → Bool) : BitStream := +def BitStream.appendLeft (n : Nat) (env : Fin n → Bool) (b : BitStream) : BitStream := match n with | 0 => b | n + 1 => BitStream.cons (env 0) (b.appendLeft n (fun k => env k.succ)) @@ -129,22 +129,29 @@ Produce the output stream differential equation as a bitstream. -/ def StreamDiffEq.toStream [DecidableEq ι] (s : StreamDiffEq ι npast) (inputStream : ι → BitStream) : BitStream := - let newStreams : ι → BitStream := fun i => - BitStream.appendLeft (inputStream i) npast - (fun kpast => s.inInit ⟨i, kpast⟩) - BitStream.ofCircuitPointwise s.outCircuit fun iv => - (newStreams iv.input).drop iv.past + fun k => + if h : k < npast then + s.initialOutputVal ⟨k, by omega⟩ + else + s.outCircuit.eval (fun input => + inputStream input.inputIx (k - input.past.val) + ) + + +@[simp] +theorem StreamDiffEq.toStream_eq_initialOutputVal_of_lt [DecidableEq ι] + (s : StreamDiffEq ι npast) + (inputStream : ι → BitStream) (hn : k < npast) : + (s.toStream inputStream) k = s.initialOutputVal ⟨k, hn⟩ := by + simp [StreamDiffEq.toStream, hn] @[simp] -theorem StreamDiffEq.toStream_eq_eval_of_lt [DecidableEq ι] +theorem StreamDiffEq.toStream_eq_eval_of_le [DecidableEq ι] (s : StreamDiffEq ι npast) - (env : InputVar ι npast → Bool) - (inputStream : ι → BitStream) (n : Nat) (hn : ix < npast) : - (s.toStream inputStream) ix = s.outCircuit.eval env := by - simp [StreamDiffEq.toStream, BitStream.eval_drop] - congr - ext i - sorry + (inputStream : ι → BitStream) (hn : npast ≤ k) : + (s.toStream inputStream) k = + s.outCircuit.eval (fun input => inputStream input.inputIx (k - input.past.val) ) := by + simp [StreamDiffEq.toStream, hn] /-- Produce the output stream differential equation as a FSM @@ -152,19 +159,117 @@ Produce the output stream differential equation as a FSM def StreamDiffEq.toFSM [DecidableEq ι] [Hashable ι] [FinEnum ι] (s : StreamDiffEq ι npast) : FSM ι where α := InputVar ι npast - initCarry := s.inInit - outputCirc := s.outCircuit.map Sum.inl + initCarry := fun _ => true + -- | completely ignore the current input? + -- | No, that seems wrong! Instead, we just need to output the current circuit output. + -- What we should do, is for the first N steps, output whatever the output + -- says, and after that, just keep outputting the circuit output. + outputCirc := + s.outCircuit.map fun iv => + if h : iv.past.val = 0 then + .inr iv.inputIx + else + .inl (InputVar.mk (iv.inputIx) ⟨iv.past.val - 1, by omega⟩) -- | we need to rotate, and send bits to the more past state. nextStateCirc := fun iv => if h : iv.past.val = 0 then - Circuit.var true <| .inr iv.input + Circuit.var true <| .inr iv.inputIx else - Circuit.var true <| .inl ⟨iv.input, ⟨iv.past.val - 1, by omega⟩⟩ - -theorem StreamDiffEq.toFsm_eval_eq_toStream [DecidableEq ι] [Hashable ι] - [FinEnum ι] (s : StreamDiffEq ι npast) (inputStream : ι → BitStream) : - (s.toFSM.eval inputStream) = (s.toStream inputStream) := by - ext i - by_cases hi : i < npast - · simp - · sorry + Circuit.var true <| .inl ⟨iv.inputIx, ⟨iv.past.val - 1, by omega⟩⟩ + +/-- +Make an FSM that overrides the output of another FSM for one clock cycle +to a constant value. +-/ +def fsmOverrideOutput (f : FSM arity) (b : Bool) : FSM arity where + α := Unit ⊕ f.α + initCarry := fun i => + match i with + | .inl () => true + | .inr a => f.initCarry a + outputCirc := + Circuit.ite (Circuit.var true <| .inl (.inl ())) + (Circuit.ofBool b) + (f.outputCirc.map fun v => + match v with + | .inl fa => .inl (.inr fa) + | .inr a => .inr a) + nextStateCirc := fun i => + match i with + | .inl () => .ofBool false -- make 'false'. + | .inr a => + (f.nextStateCirc a).map fun v => + match v with + | .inl fa => .inl (.inr fa) + | .inr a => .inr a + +@[simp] +theorem eval_FsmOverrideOutput_zero + {f : FSM arity} {b : Bool} {env : arity → BitStream} : + (fsmOverrideOutput f b).eval env 0 = b := by + simp [FSM.eval, fsmOverrideOutput, FSM.nextBit] + split_ifs + case pos h => simp [h] + case neg h => simp [h] + +@[simp] +theorem carry_fsmOverrideOutput_eq + {f : FSM arity} {b : Bool} {env : arity → BitStream} : + ∀ (a : f.α), ((fsmOverrideOutput f b).carry env n) (.inr a) = (f.carry env n) a := by + induction n generalizing env b + case zero => + intros a + simp [fsmOverrideOutput, FSM.carry, FSM.nextBit, Circuit.eval_map] + case succ n ihn => + intros a + simp [fsmOverrideOutput, FSM.carry, FSM.nextBit, Circuit.eval_map] + congr + ext i + rcases i with a | i + · simp only [Sum.elim_inl] + rw [← ihn (env := env) (b := b)] + simp [fsmOverrideOutput] + · simp + +@[simp] +theorem eval_FsmOverrideOutput_succ {f : FSM arity} {b : Bool} : + (fsmOverrideOutput f b).eval env n = + if n = 0 then b else f.eval env n := by + -- | TODO: replace all FSM proofs with `eval_induction_1`? + -- TODO: Write about this reasoning principle in the paper. + induction n using FSM.eval_induction_1 + (fsm := fsmOverrideOutput f b) + (inputs := env) + (SInv := fun (i : Nat) (state : Unit ⊕ f.α → Bool) => + (∀ a, state (.inr a) = (f.carry env i) a) ∧ (state (.inl ()) = decide (i = 0))) + case hstate0 => + constructor + · intros a + simp [fsmOverrideOutput] + · simp [fsmOverrideOutput] + case hStateSucc k state ih => + simp [fsmOverrideOutput, FSM.nextBitState, FSM.nextBit] + simp [FSM.carry] + simp [FSM.nextBit, Circuit.eval_map] + intros a + congr + ext i + rcases i with a | i + · simp [ih] + · simp + case hEval k state ih => + simp [fsmOverrideOutput, FSM.nextBitOutput, FSM.nextBit] + obtain ⟨ih1, ih2⟩ := ih + simp [ih2] + split + case isTrue hk => + subst hk + rcases hb : b <;> simp + case isFalse hk => + simp [Circuit.eval_map] + simp [FSM.eval, FSM.nextBit] + congr + ext i + rcases i with a | i + · simp [ih1] + · simp From 51c1d24df1442b4a31ae21a4050f9ea96b26d4ce Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Fri, 17 Oct 2025 17:41:10 +0800 Subject: [PATCH 5/5] chore: fix proof --- Blase/Blase/MultiWidth/StreamDiffEq.lean | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Blase/Blase/MultiWidth/StreamDiffEq.lean b/Blase/Blase/MultiWidth/StreamDiffEq.lean index c820a53ac3..4278c1b9fb 100644 --- a/Blase/Blase/MultiWidth/StreamDiffEq.lean +++ b/Blase/Blase/MultiWidth/StreamDiffEq.lean @@ -42,7 +42,7 @@ instance [instFinEnumI : FinEnum ι] : FinEnum (InputVar ι npast) where ⟨finInput.val * instFinEnumFinNpast.card + finPast.val, by apply Nat.mul_lt_mul_add - · sorry + · grind · simp only [FinEnum.card_fin] have := finPast.isLt simpa using this