Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ jobs:
echo "ALIVE_ALL=$(cat all)" >> $GITHUB_ENV
echo "ALIVE_FAILED=$(cat failed)" >> $GITHUB_ENV

- name: Compile Alive Examples (With UB)
run: |
lake -R build SSA.Projects.SLLVM.Evaluation.AliveAutoGeneratedCopy

- name: Compile DC
run: |
lake -R build SSA.Projects.CIRCT.Handshake.Handshake # compile and check CIRCT's Handshake Dialect
Expand Down
1 change: 1 addition & 0 deletions SSA.lean
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import SSA.Projects.CSE.CSE
import SSA.Projects.PaperExamples.PaperExamples
import SSA.Projects.Scf.ScfFunctor
import SSA.Projects.LeanMlirCommon.LeanMlirCommon
import SSA.Projects.SLLVM.SLLVM


-- EXPERIMENTAL
Expand Down
4 changes: 2 additions & 2 deletions SSA/Core/Framework.lean
Original file line number Diff line number Diff line change
Expand Up @@ -803,15 +803,15 @@ variable [DialectHRefinement d d]
An expression `e₁` is refined by an expression `e₂` (of the same dialect) if their
respective denotations under every valuation are in the refinement relation.
-/
instance: Refinement (Expr d Γ eff t) where
instance: HRefinement (Expr d Γ eff₁ t) (Expr d Γ eff₂ t) where
IsRefinedBy e₁ e₂ :=
∀ V, e₁.denote V ⊑ e₂.denote V

/--
A program `c₁` is refined by a program `c₂` (of the same dialect) if their
respective denotations under every valuation are in the refinement relation.
-/
instance: Refinement (Com d Γ eff t) where
instance : HRefinement (Com d Γ eff₁ t) (Com d Γ eff₂ t) where
IsRefinedBy c₁ c₂ :=
∀ V, c₁.denote V ⊑ c₂.denote V

Expand Down
9 changes: 9 additions & 0 deletions SSA/Core/Tactic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,17 @@ macro "simp_peephole" loc:(location)? : tactic =>
| rw [funext_iff (α := Ctxt.Valuation _)] $[$loc]?
| change ∀ (_ : Ctxt.Valuation _), _ $[$loc]?
| skip

-- Then, we simplify with the `simp_denote` simpset
simp (config := {failIfUnchanged := false}) only [simp_denote] $[$loc]?
))

/-
TODO: Implement a check for `LawfulMonad`.
If a dialect's monad doesn't implement LawfulMonad, then currently `simp_peephole`
just silently fails to apply `Com.denote_var`, leaving the goal state messy
without any indication on what's wrong. We should catch this, and show a failed
to synthesize error.
-/

end SSA
51 changes: 48 additions & 3 deletions SSA/Core/Util/Poison.lean
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,24 @@ def casesOn'.{u} {α : Type} {motive : PoisonOr α → Sort u}
| .poison => poison
| .value a => value a

@[simp] theorem value_inj {a b : α} : value a = value b ↔ a = b := by
@[simp] theorem value_inj {a b : α} :
@Eq (no_index _) (value a) (value b) ↔ a = b := by
-- ^^ `value a = value b ↔ _`
constructor
· rintro ⟨⟩; rfl
· exact fun h => h ▸ rfl

theorem poison_ne_value (a : α) : poison ≠ value a := by rintro ⟨⟩
theorem value_ne_poison (a : α) : value a ≠ poison := by rintro ⟨⟩
theorem poison_ne_value (a : α) :
@Ne (no_index _) poison (value a) := by -- `poison ≠ value a`
rintro ⟨⟩
theorem value_ne_poison (a : α) :
@Ne (no_index _) (value a) poison := by -- `value a ≠ poison`
rintro ⟨⟩

@[simp]
theorem ite_value_value {c : Prop} [Decidable c] {a b : α} :
(if c then value a else value b : no_index _) = value (if c then a else b) := by
split <;> rfl

/-! ### Formatting & Priting instances -/
instance [ToString α] : ToString (PoisonOr α) where
Expand Down Expand Up @@ -201,6 +212,40 @@ instance [DecidableRel (· ⊑ · : α → α → _)] :
| .value _, .poison => .isFalse <| by simp
| .value a, .value b => decidable_of_decidable_of_iff (p := a ⊑ b) <| by simp


/-! ### if-then-else -/
section Ite
variable {c : Prop} [Decidable c] (a? b? : PoisonOr α) (a : α)

@[simp]
theorem if_then_poison_isRefinedBy_iff :
(if c then poison else a? : no_index _) ⊑ b? ↔ ¬c → a? ⊑ b? := by
split <;> simp [*]

@[simp]
theorem value_isRefinedBy_if_then_poison_iff :
value a ⊑ (if c then poison else b? : no_index _) ↔ ¬c ∧ (value a ⊑ b?) := by
split <;> simp [*]


/-!
Fallback theorems for generic if-then-else; other theorems should be preferred
as they give simpler rhs's for their specialized situations.
-/
theorem ite_isRefinedBy_iff {x? y? z? : PoisonOr α} :
ite c x? y? ⊑ z?
↔ let c := c
(c → x? ⊑ z?) ∧ (¬c → y? ⊑ z?) := by
split <;> simp [*]

theorem isRefinedBy_ite_iff {x? y? z? : PoisonOr α} :
x? ⊑ ite c y? z?
↔ let c := c
(c → x? ⊑ y?) ∧ (¬c → x? ⊑ z?) := by
split <;> simp [*]

end Ite

end Refinement


Expand Down
4 changes: 4 additions & 0 deletions SSA/Projects/InstCombine/ForStd.lean
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,8 @@ theorem Int.natCast_pred_of_pos (x : Nat) (h : 0 < x) :
(-·), Int.neg, Int.negOfNat, Int.subNatNat]
simp

theorem ofBool_eq_one_iff (b : Bool) :
ofBool b = 1#1 ↔ b = true := by
cases b <;> simp

end BitVec
6 changes: 6 additions & 0 deletions SSA/Projects/SLLVM/Dialect.lean
Original file line number Diff line number Diff line change
@@ -1 +1,7 @@

import SSA.Projects.SLLVM.Dialect.Base
import SSA.Projects.SLLVM.Dialect.Refinement
import SSA.Projects.SLLVM.Dialect.Parser

import SSA.Projects.InstCombine.LLVM.PrettyEDSL
-- ^^ ensure we have the pretty EDSL macros available
37 changes: 35 additions & 2 deletions SSA/Projects/SLLVM/Dialect/Base.lean
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,49 @@ existing LLVM dialect, *but* it refines the semantics to include a proper model
of (side-effecting) UB.
-/

/-! ### Dialect definition -/

def SLLVM : Dialect where
Op := LLVM.Op
Ty := LLVM.Ty
m := UBOr

namespace SLLVM

instance : DecidableEq SLLVM.Op := by unfold SLLVM; infer_instance
instance : DecidableEq SLLVM.Ty := by unfold SLLVM; infer_instance
instance : ToString SLLVM.Op := by unfold SLLVM; infer_instance
instance : ToString SLLVM.Ty := by unfold SLLVM; infer_instance
instance : Repr SLLVM.Op := by unfold SLLVM; infer_instance
instance : Repr SLLVM.Ty := by unfold SLLVM; infer_instance
instance : Lean.ToExpr SLLVM.Op := by unfold SLLVM; infer_instance
instance : Lean.ToExpr SLLVM.Ty := by unfold SLLVM; infer_instance

instance : TyDenote SLLVM.Ty := by unfold SLLVM; infer_instance
instance : Monad SLLVM.m := by unfold SLLVM; infer_instance
instance : LawfulMonad SLLVM.m := by unfold SLLVM; infer_instance

open Qq in instance : DialectToExpr SLLVM where
toExprDialect := q(SLLVM)
toExprM := q(Id.{0})

@[match_pattern]
abbrev Ty.bitvec : Nat → SLLVM.Ty :=
InstCombine.LLVM.Ty.bitvec

@[simp_denote] theorem toType_bitvec : TyDenote.toType (Ty.bitvec w) = LLVM.IntW w := rfl

end SLLVM

/-! ### Signature -/

open InstCombine.LLVM.Op in
/-- The signature of each operation is the same as in LLVM. -/
instance : DialectSignature SLLVM where
signature op :=
{ DialectSignature.signature (d := LLVM) op with
effectKind := match op with
| udiv .. | sdiv .. => .impure
| udiv .. | sdiv .. | urem .. | srem .. => .impure
| _ => .pure
}

Expand All @@ -49,7 +80,9 @@ open InstCombine.LLVM.Op in
instance : DialectDenote SLLVM where
denote
| udiv _ flag => fun (x ::ₕ (y ::ₕ .nil)) _ => LeanMLIR.SLLVM.udiv x y flag
| sdiv _ flag => fun (x ::ₕ (y ::ₕ .nil)) _ => LeanMLIR.SLLVM.udiv x y flag
| sdiv _ flag => fun (x ::ₕ (y ::ₕ .nil)) _ => LeanMLIR.SLLVM.sdiv x y flag
| urem _ => fun (x ::ₕ (y ::ₕ .nil)) _ => LeanMLIR.SLLVM.urem x y
| srem _ => fun (x ::ₕ (y ::ₕ .nil)) _ => LeanMLIR.SLLVM.srem x y
| op => fun args .nil =>
EffectKind.liftEffect (EffectKind.pure_le _) <|
DialectDenote.denote (d := LLVM) op args .nil
Expand Down
129 changes: 129 additions & 0 deletions SSA/Projects/SLLVM/Dialect/Parser.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/-
Released under Apache 2.0 license as described in the file LICENSE.
-/
import Qq
import SSA.Core.MLIRSyntax.EDSL2
import SSA.Core.MLIRSyntax.Transform.Utils

import SSA.Projects.SLLVM.Dialect.Base

namespace LeanMLIR.SLLVM

-- open Qq Lean Meta Elab.Term Elab Command
open MLIR
open MLIR.AST

private abbrev ReaderM := MLIR.AST.ReaderM SLLVM

def mkTy : MLIRType 0 → ExceptM SLLVM SLLVM.Ty
| .int .Signless w => return Ty.bitvec w.toConcrete
| _ => throw .unsupportedType

instance : TransformTy SLLVM 0 where
mkTy := mkTy


def getOutputWidth (opStx : MLIR.AST.Op φ) (op : String) :
Except TransformError (Width φ) := do
match opStx.res with
| res::[] =>
match res.2 with
| .int _ w => pure w
| _ => throw <| .generic s!"The operation {op} must output an integer type"
| _ => throw <| .generic s!"The operation {op} must have a single output"

/-- Given a variable of arbitrary type, return its width.

This relies on the fact that `bitvec _` is the only type currently modelled.
In future, this might start throwing errors, if the type is not a bitvec.
-/
def getVarWidth {Γ : Ctxt SLLVM.Ty} : (Σ t, Γ.Var t) → Nat
| ⟨Ty.bitvec w, _⟩ => w

def parseOverflowFlags (op : AST.Op 0) : ReaderM LLVM.NoWrapFlags :=
match op.getAttr? "overflowFlags" with
| .none => return {}
| .some y => match y with
| .opaque_ "llvm.overflow" "nsw" => return ⟨true, false⟩
| .opaque_ "llvm.overflow" "nuw" => return ⟨false, true⟩
| .list [.opaque_ "llvm.overflow" "nuw", .opaque_ "llvm.overflow" "nsw"]
| .list [.opaque_ "llvm.overflow" "nsw", .opaque_ "llvm.overflow" "nuw"] =>
return ⟨true, true⟩
| .opaque_ "llvm.overflow" s => throw <| .generic s!"The overflow flag {s} not allowed. \
We currently support nsw (no signed wrap) and nuw (no unsigned wrap)"
| _ => throw <| .generic s!"Unrecognised overflow flag found: {MLIR.AST.docAttrVal y}. \
We currently support nsw (no signed wrap) and nuw (no unsigned wrap)"

open InstCombine (MOp) in
instance : TransformExpr SLLVM 0 where
mkExpr (Γ : Ctxt SLLVM.Ty) (opStx : MLIR.AST.Op 0) := do
let args ← opStx.parseArgs Γ

/- `binW` is the (lazily computed) width, assuming a binary operation -/
let binW := do
-- ^^ NOTE: `binW` is bound with := rather than ← on purpose, to ensure laziness
let args ← args.assumeArity 2
return getVarWidth args[0]

/- `unW` is the (lazily computed) width, assuming an unary operation -/
let unW := do
let args ← args.assumeArity 1
return getVarWidth args[0]

let mkExprOf := opStx.mkExprOf (args? := args) Γ
match opStx.name with
-- Ternary Operations
| "llvm.select" =>
let args ← args.assumeArity 3
let w := getVarWidth args[1]
mkExprOf <| MOp.select w
-- Binary Operations
| "llvm.and" => mkExprOf <| MOp.and (← binW)
| "llvm.or" => mkExprOf <| MOp.or (← binW) ⟨opStx.hasAttr "isDisjoint"⟩
| "llvm.xor" => mkExprOf <| MOp.xor (← binW)
| "llvm.urem" => mkExprOf <| MOp.urem (← binW)
| "llvm.srem" => mkExprOf <| MOp.srem (← binW)
| "llvm.lshr" => mkExprOf <| MOp.lshr (← binW) ⟨opStx.hasAttr "isExact"⟩
| "llvm.ashr" => mkExprOf <| MOp.ashr (← binW) ⟨opStx.hasAttr "isExact"⟩
| "llvm.sdiv" => mkExprOf <| MOp.sdiv (← binW) ⟨opStx.hasAttr "isExact"⟩
| "llvm.udiv" => mkExprOf <| MOp.udiv (← binW) ⟨opStx.hasAttr "isExact"⟩
| "llvm.shl" => mkExprOf <| MOp.shl (← binW) (← parseOverflowFlags opStx)
| "llvm.add" => mkExprOf <| MOp.add (← binW) (← parseOverflowFlags opStx)
| "llvm.mul" => mkExprOf <| MOp.mul (← binW) (← parseOverflowFlags opStx)
| "llvm.sub" => mkExprOf <| MOp.sub (← binW) (← parseOverflowFlags opStx)
| "llvm.icmp.eq" => mkExprOf <| MOp.icmp .eq (← binW)
| "llvm.icmp.ne" => mkExprOf <| MOp.icmp .ne (← binW)
| "llvm.icmp.ugt" => mkExprOf <| MOp.icmp .ugt (← binW)
| "llvm.icmp.uge" => mkExprOf <| MOp.icmp .uge (← binW)
| "llvm.icmp.ult" => mkExprOf <| MOp.icmp .ult (← binW)
| "llvm.icmp.ule" => mkExprOf <| MOp.icmp .ule (← binW)
| "llvm.icmp.sgt" => mkExprOf <| MOp.icmp .sgt (← binW)
| "llvm.icmp.sge" => mkExprOf <| MOp.icmp .sge (← binW)
| "llvm.icmp.slt" => mkExprOf <| MOp.icmp .slt (← binW)
| "llvm.icmp.sle" => mkExprOf <| MOp.icmp .sle (← binW)
-- Unary Operations
| "llvm.not" => mkExprOf <| MOp.not (← unW)
| "llvm.neg" => mkExprOf <| MOp.neg (← unW)
| "llvm.copy" => mkExprOf <| MOp.copy (← unW)
| "llvm.zext" => mkExprOf <| MOp.zext (← unW) (← getOutputWidth opStx "zext") ⟨ opStx.hasAttr "nonNeg" ⟩
| "llvm.sext" => mkExprOf <| MOp.sext (← unW) (← getOutputWidth opStx "sext")
| "llvm.trunc" => mkExprOf <| MOp.trunc (← unW) (← getOutputWidth opStx "trunc") (← parseOverflowFlags opStx)
-- Constant
| "llvm.mlir.constant" =>
let ⟨val, ty⟩ ← opStx.getIntAttr "value"
let opTy@(.bitvec w) ← mkTy ty
mkExprOf <| MOp.const w val
-- Fallback
| opName => throw <| .unsupportedOp opName

instance : TransformReturn SLLVM 0 where
mkReturn (Γ : Ctxt SLLVM.Ty) (opStx : MLIR.AST.Op 0) := do
if opStx.name ≠ "llvm.return" then
throw <| .unsupportedOp s!"Tried to build return out of non-return statement {opStx.name}"
else
let args ← (← opStx.parseArgs Γ).assumeArity 1
let ⟨ty, v⟩ := args[0]
return ⟨.pure, ty, Com.ret v⟩

elab "[sllvm| " reg:mlir_region "]" : term =>
SSA.elabIntoCom' reg SLLVM
33 changes: 33 additions & 0 deletions SSA/Projects/SLLVM/Dialect/Refinement.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import SSA.Projects.SLLVM.Dialect.Base

/-!
# Refinement relation for SLLVM dialect
-/

namespace LeanMLIR.SLLVM

open InstCombine.LLVM.Ty (bitvec)

scoped instance : Refinement (BitVec w) := .ofEq
@[simp, simp_llvm_split] theorem bv_isRefinedBy_iff (x y : BitVec w) : x ⊑ y ↔ x = y := by rfl
-- ^^ declare that for pure bitvectors, refinement is just equality

instance (w) : Refinement (UBOr <| LLVM.IntW w) where
IsRefinedBy (x y : PoisonOr (PoisonOr _)) := x ⊑ y

open HRefinement in
@[simp, simp_denote]
theorem isRefinedBy_iff (x y : UBOr <| LLVM.IntW w) :
x ⊑ y ↔ @IsRefinedBy (PoisonOr <| PoisonOr _) (PoisonOr <| PoisonOr _) _ x y := by
rfl

instance : DialectHRefinement SLLVM SLLVM where
IsRefinedBy := @fun
| bitvec v, bitvec w, (x : PoisonOr (PoisonOr _)), (y : PoisonOr (PoisonOr _)) =>
∃ h : v = w, x ⊑ h ▸ y

@[simp, simp_denote]
theorem dialect_isRefinedBy_iff_of_width_eq (x y : SLLVM.m ⟦bitvec w⟧) :
DialectHRefinement.IsRefinedBy x y
↔ HRefinement.IsRefinedBy (α := UBOr <| LLVM.IntW w) (β := UBOr <| LLVM.IntW w) x y := by
simp [DialectHRefinement.IsRefinedBy]
Loading
Loading