From f913935f88de95465e0ae2d02b5528556e8dc1aa Mon Sep 17 00:00:00 2001 From: Pavel Potapov Date: Fri, 6 Mar 2026 16:16:07 +0000 Subject: [PATCH] Init identity matrix and jump to main --- KLR/Trace/Types.lean | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/KLR/Trace/Types.lean b/KLR/Trace/Types.lean index 33a0dbc2..8ddc5ff5 100644 --- a/KLR/Trace/Types.lean +++ b/KLR/Trace/Types.lean @@ -429,7 +429,13 @@ def addId : Trace Unit := do throw "identity already initialized" let dtype := .int8 let shape := Core.Shape.mk 128 [128] - let tensorName <- Core.TensorName.make idName.toString dtype shape none (<- flags.address_rotation) + let idSbufAddr : Address := { + name := idName.toString, + memory := .sbuf, + parSize := 128 + freeSize := 128 + } + let tensorName <- Core.TensorName.make idName.toString dtype shape idSbufAddr (<- flags.address_rotation) let id : KLR.Core.TensorRef := .abstract (.simple tensorName) let pos : Pos := { line := 0, column := 0 } let hbmInitName := <-genName @@ -452,8 +458,18 @@ def addId : Trace Unit := do }) none pos let lbl := (<- genLabel `init) let idTensor := identity 128 + -- Jump from init block to the first main block so the CFG has a single exit. + let mainLabel := match (<- get).label with + | some l => l + | none => "main.1" -- fallback; should not happen + let jmpName := (<- genLabel `jmp) + let jmpStmt := Core.Stmt.oper (.cmpBranch { + reg1 := "", reg2 := "", imm := 0, + op := BrCmpOp.always, + trueLabel := mainLabel, falseLabel := "" + }) jmpName pos modify fun s => { s with - body := #[Block.mk lbl false [initStmt]] ++ s.body, + body := #[Block.mk lbl false [initStmt, jmpStmt]] ++ s.body, sharedConstants := s.sharedConstants.push (hbmInitName.toString, idTensor) } extend_global idName (.access (.simple tensorName))