From 55b8d4a4a821d4ce57398380fe8d476c5733a949 Mon Sep 17 00:00:00 2001 From: Joe Kerrigan Date: Wed, 25 Feb 2026 23:28:38 +0000 Subject: [PATCH] fix: reverse order of no_reorder dependency edges PR #604 altered the semantics of no_reorder so that instructions executing on distinct engines can still be scheduled in parallel. Unfortunately, it also reversed the order of dependency edges emitted by no_reorder. Downstream components interpret an edge (a, b) as 'b depends on a', not the other way around. Update to preserve existing convention, and ensure user-provided edges follow this convention as well. --- KLR/NKI/Simplify.lean | 10 +++++++--- KLR/Trace/Types.lean | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/KLR/NKI/Simplify.lean b/KLR/NKI/Simplify.lean index c473fc4f..9aec2128 100644 --- a/KLR/NKI/Simplify.lean +++ b/KLR/NKI/Simplify.lean @@ -509,7 +509,7 @@ private def edgeName (e : Expr) : Simplify String := | .value (.string s) => return s | _ => throw "Schedule edge name must be a string" -private def edge (py : Python.Expr) : Simplify Edges := do +private def edge (py : Python.Expr) : Simplify (List Edges) := do let e <- expr py match e.expr with | .tuple [e1, e2] => @@ -517,7 +517,11 @@ private def edge (py : Python.Expr) : Simplify Edges := do let e2 <- match e2.expr with | .list es => es.mapM edgeName | _ => pure [<- edgeName e2] - return .mk e1 e2 + -- Edges are represented in Python like (x, [y1, y2, ...]). + -- This means that x depends on y1, y2, ... and must be scheduled after each of them. + -- However components consuming KLIR expect the reverse: y1, y2, ... each depend on x. + -- Reverse user-provided scheduling edges to address this. + return e2.map (.mk · [e1]) | _ => throw "Schedule edge must be a pair" private def flags (py : Python.Expr) : Simplify (String × Value) := do @@ -549,6 +553,6 @@ def simplify (py : Python.Kernel) : Simplify Kernel := do globals := <- kwargs py.globals arch := py.arch grid := py.grid - edges := <- py.scheduleEdges.mapM edge + edges := <- py.scheduleEdges.foldlM (fun acc e => return (<- edge e) ++ acc) [] flags := <- py.flags.mapM flags } diff --git a/KLR/Trace/Types.lean b/KLR/Trace/Types.lean index 294d95a0..9719b270 100644 --- a/KLR/Trace/Types.lean +++ b/KLR/Trace/Types.lean @@ -355,7 +355,7 @@ private def swapLast (engine : Engine) (name : String) (last : LastInst) : LastI private def updateLast (engine : Engine) (name : String) (state : State) : State := if state.noReorderDepth == 0 then state else let (last, names) := swapLast engine name state.lastInst - let edges := names.map (name, ·) + let edges := names.map (·, name) { state with lastInst := last, edges := edges ++ state.edges } def add_stmt (stmt : Pos -> Stmt) : Trace Unit := do