Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions KLR/NKI/Simplify.lean
Original file line number Diff line number Diff line change
Expand Up @@ -509,15 +509,19 @@ private def edgeName (e : Expr) : Simplify String :=
| .value (.string s) => return s
| _ => throw "Schedule edge name must be a string"

private def edge (py : Python.Expr) : Simplify Edges := do
private def edge (py : Python.Expr) : Simplify (List Edges) := do
let e <- expr py
match e.expr with
| .tuple [e1, e2] =>
let e1 <- edgeName e1
let e2 <- match e2.expr with
| .list es => es.mapM edgeName
| _ => pure [<- edgeName e2]
return .mk e1 e2
-- Edges are represented in Python like (x, [y1, y2, ...]).
-- This means that x depends on y1, y2, ... and must be scheduled after each of them.
-- However components consuming KLIR expect the reverse: y1, y2, ... each depend on x.
-- Reverse user-provided scheduling edges to address this.
return e2.map (.mk · [e1])
| _ => throw "Schedule edge must be a pair"

private def flags (py : Python.Expr) : Simplify (String × Value) := do
Expand Down Expand Up @@ -549,6 +553,6 @@ def simplify (py : Python.Kernel) : Simplify Kernel := do
globals := <- kwargs py.globals
arch := py.arch
grid := py.grid
edges := <- py.scheduleEdges.mapM edge
edges := <- py.scheduleEdges.foldlM (fun acc e => return (<- edge e) ++ acc) []
flags := <- py.flags.mapM flags
}
2 changes: 1 addition & 1 deletion KLR/Trace/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ private def swapLast (engine : Engine) (name : String) (last : LastInst) : LastI
private def updateLast (engine : Engine) (name : String) (state : State) : State :=
if state.noReorderDepth == 0 then state else
let (last, names) := swapLast engine name state.lastInst
let edges := names.map (name, ·)
let edges := names.map (·, name)
Comment thread
kerrijoe-aws marked this conversation as resolved.
{ state with lastInst := last, edges := edges ++ state.edges }

def add_stmt (stmt : Pos -> Stmt) : Trace Unit := do
Expand Down
Loading