diff --git a/KLR/NKI/Simplify.lean b/KLR/NKI/Simplify.lean index c473fc4f..9aec2128 100644 --- a/KLR/NKI/Simplify.lean +++ b/KLR/NKI/Simplify.lean @@ -509,7 +509,7 @@ private def edgeName (e : Expr) : Simplify String := | .value (.string s) => return s | _ => throw "Schedule edge name must be a string" -private def edge (py : Python.Expr) : Simplify Edges := do +private def edge (py : Python.Expr) : Simplify (List Edges) := do let e <- expr py match e.expr with | .tuple [e1, e2] => @@ -517,7 +517,11 @@ private def edge (py : Python.Expr) : Simplify Edges := do let e2 <- match e2.expr with | .list es => es.mapM edgeName | _ => pure [<- edgeName e2] - return .mk e1 e2 + -- Edges are represented in Python like (x, [y1, y2, ...]). + -- This means that x depends on y1, y2, ... and must be scheduled after each of them. + -- However components consuming KLIR expect the reverse: y1, y2, ... each depend on x. + -- Reverse user-provided scheduling edges to address this. + return e2.map (.mk · [e1]) | _ => throw "Schedule edge must be a pair" private def flags (py : Python.Expr) : Simplify (String × Value) := do @@ -549,6 +553,6 @@ def simplify (py : Python.Kernel) : Simplify Kernel := do globals := <- kwargs py.globals arch := py.arch grid := py.grid - edges := <- py.scheduleEdges.mapM edge + edges := <- py.scheduleEdges.foldlM (fun acc e => return (<- edge e) ++ acc) [] flags := <- py.flags.mapM flags } diff --git a/KLR/Trace/Types.lean b/KLR/Trace/Types.lean index 294d95a0..9719b270 100644 --- a/KLR/Trace/Types.lean +++ b/KLR/Trace/Types.lean @@ -355,7 +355,7 @@ private def swapLast (engine : Engine) (name : String) (last : LastInst) : LastI private def updateLast (engine : Engine) (name : String) (state : State) : State := if state.noReorderDepth == 0 then state else let (last, names) := swapLast engine name state.lastInst - let edges := names.map (name, ·) + let edges := names.map (·, name) { state with lastInst := last, edges := edges ++ state.edges } def add_stmt (stmt : Pos -> Stmt) : Trace Unit := do