diff --git a/KLR/Trace/ISA.lean b/KLR/Trace/ISA.lean index 3401ebcd..bed39319 100644 --- a/KLR/Trace/ISA.lean +++ b/KLR/Trace/ISA.lean @@ -170,16 +170,40 @@ nki builtin.isa.nc_transpose match engine with | .pe => let N := data.shapePure.freeDims.getLast! - let id : TensorRef := <- match <- lookup_global? (.num `identity 0) with - | some (.access acc) => return .abstract acc - | some _ => throw "identity has wrong type" - | none => throw "identity not defined" - let idName : TensorName <- match id with - | .abstract $ .simple t => pure t - | .abstract $ .basic t => pure t.tensor - | .abstract $ .pattern t => pure t.tensor - | _ => throw "Expected identity matrix to be a ref" - let idSlice : TensorRef := .abstract $ .basic $ <- AccessBasic.make idName [ + let dtype := data.tensor.dtype + -- Create identity matrix inline with the data's dtype + -- using memset(0) + affine_select, matching beta3 approach. + let idShape := Core.Shape.mk 128 [128] + let idName := (<- genName).toString + let idAddr : Address := { + name := idName, + memory := .sbuf, + parSize := 128 + freeSize := 128 * dtype.size + } + let idTensor <- Core.TensorName.make idName dtype idShape (some idAddr) (<- flags.address_rotation) + let idRef : TensorRef := .abstract (.simple idTensor) + -- Zero the identity tensor + Trace.add_stmt $ .oper (.memSet { + dst := idRef, + value := .float 0.0, + dtype := dtype, + engine := .unassigned + }) (<- genLabel `memset_id) + -- Write 1.0 on the diagonal using affine_select: + -- pattern [[0,1],[0,1],[0,1],[1,128]] with channel_multiplier=-1 + -- produces (free_idx - partition_idx); cmp_op=not_equal keeps + -- on_true_tile (zeros) off-diagonal, writes on_false_value (1.0) + -- on the diagonal where the expression equals 0. + Trace.add_stmt $ .oper (.ncAffineSelect { + dst := idRef, + pred := ⟨0, [⟨0, 1, 0⟩, ⟨0, 1, 0⟩, ⟨0, 1, 0⟩, ⟨1, 128, 0⟩], -1⟩, + onTrueTile := idRef, + onFalseValue := .float 1.0, + dtype := some dtype, + cmpOp := .not_equal, + }) (<- genLabel `affsel_id) + let idSlice : TensorRef := .abstract $ .basic $ <- AccessBasic.make idTensor [ .slice $ Slice.make! 0 N 1, .slice $ Slice.make! 0 N 1 ] @@ -187,7 +211,7 @@ nki builtin.isa.nc_transpose dst := .abstract dst, stationary := idSlice, moving := .abstract data, - isStationaryOneZero := false, + isStationaryOneZero := true, isMovingZero := false, isTranspose := true, tilePosition := [], diff --git a/KLR/Trace/NKI.lean b/KLR/Trace/NKI.lean index 2fa4b9ff..6363faa1 100644 --- a/KLR/Trace/NKI.lean +++ b/KLR/Trace/NKI.lean @@ -595,7 +595,6 @@ partial def lowerRes (t: Term) : Trace (List Core.Access) := do def traceKernel (k : Kernel) : Trace Core.Kernel := do let _ <- beginBlock (<- genLabel `main) - addId globals k flags k.flags match k.funs.find? fun f => f.name == k.entry with