diff --git a/KLR/Core/Operators.lean b/KLR/Core/Operators.lean index 65a03bed..84246b24 100644 --- a/KLR/Core/Operators.lean +++ b/KLR/Core/Operators.lean @@ -398,6 +398,7 @@ structure NcDmaCopy where oobMode : DmaBounds dgeMode : Nat uniqueIndices : Bool + priority : Option Nat engine : Engine deriving BEq, FromCBOR, FromJson, FromSexp, Repr, ToCBOR, ToJson, ToSexp @@ -415,6 +416,7 @@ structure DmaTranspose where dtype : Option Dtype dgeMode : Nat oobMode : DmaBounds + priority : Option Nat deriving BEq, FromCBOR, FromJson, FromSexp, Repr, ToCBOR, ToJson, ToSexp instance : MapTensorRefs DmaTranspose where @@ -824,6 +826,8 @@ structure TensorScalarReduce where dtype : Option Dtype reduceOp : Option AluOp reduceRes : TensorRef + reduceCmd : AccumCmd + reduceInit : Option Immediate deriving BEq, FromCBOR, FromJson, FromSexp, Repr, ToCBOR, ToJson, ToSexp instance : MapTensorRefs TensorScalarReduce where @@ -1037,6 +1041,7 @@ structure CollectiveOp where sourceTargetPairs : Option (List (List Int)) := none channel_id : Option Int := none num_channels : Option Int := none + priority : Option Nat := none deriving BEq, FromCBOR, FromJson, FromSexp, Repr, ToCBOR, ToJson, ToSexp instance : MapTensorRefs CollectiveOp where diff --git a/KLR/Trace/ISA.lean b/KLR/Trace/ISA.lean index bed39319..dc3b80c8 100644 --- a/KLR/Trace/ISA.lean +++ b/KLR/Trace/ISA.lean @@ -455,6 +455,8 @@ nki builtin.isa.tensor_scalar_reduce (reduce_op : AluOp) (reduce_res : Access) (reverse0 : Bool := false) + (reduce_cmd : AccumCmd := .Zero) + (reduce_init : Option Immediate := none) (mask : Option Immediate := none) (name : Option String := none) := do if mask.isSome then throw maskNotSupported @@ -469,6 +471,8 @@ nki builtin.isa.tensor_scalar_reduce reduceRes := .abstract reduce_res reverse0 := reverse0 dtype := dst.tensor.dtype + reduceCmd := reduce_cmd + reduceInit := reduce_init }) name return .none @@ -745,6 +749,7 @@ nki builtin.isa.dma_copy (dge_mode : Nat := 0) (unique_indices : Bool := false) (engine : Engine := .unassigned) + (priority : Option Nat := .none) (name : Option String := none) := do if mask.isSome then throw maskNotSupported let op : DgeComputeOp := <- match dst_rmw_op with @@ -763,6 +768,7 @@ nki builtin.isa.dma_copy | _ => .skip, dgeMode := dge_mode, uniqueIndices := unique_indices + priority engine }) name return .none @@ -775,6 +781,7 @@ nki builtin.isa.dma_transpose (mask : Option Immediate := none) (dge_mode : Nat := 0) (oob_mode : Nat := 0) + (priority : Option Nat := .none) (name : Option String := none) := do if mask.isSome then throw maskNotSupported if oob_mode > 1 then throw "unsupported oob mode" @@ -792,6 +799,7 @@ nki builtin.isa.dma_transpose | 0 => .error | 1 => .skip | _ => .error, + priority }) name return .none @@ -999,12 +1007,14 @@ nki builtin.isa.all_reduce (srcs : List Access) (dsts : List Access) (replica_group: Sum String (List (List Int))) + (priority : Option Nat := none) (name : Option String := none) := do Trace.add_stmt $ .oper (.allReduce { dsts := dsts.map .abstract srcs := srcs.map .abstract op := some op replicaGroup := replicaGroupFromSum replica_group + priority }) name return .none @@ -1014,12 +1024,14 @@ nki builtin.isa.all_gather (dsts : List Access) (replica_group : Sum String (List (List Int))) (concat_dim : Int) + (priority : Option Nat := none) (name : Option String := none) := do Trace.add_stmt $ .oper (.allGather { dsts := dsts.map .abstract srcs := srcs.map .abstract replicaGroup := replicaGroupFromSum replica_group concatDim := some concat_dim + priority }) name return .none @@ -1030,6 +1042,7 @@ nki builtin.isa.reduce_scatter (dsts : List Access) (replica_group : Sum String (List (List Int))) (concat_dim : Int) + (priority : Option Nat := none) (name : Option String := none) := do Trace.add_stmt $ .oper (.reduceScatter { dsts := dsts.map .abstract @@ -1037,6 +1050,7 @@ nki builtin.isa.reduce_scatter op := some op replicaGroup := replicaGroupFromSum replica_group concatDim := some concat_dim + priority }) name return .none @@ -1046,12 +1060,14 @@ nki builtin.isa.all_to_all (dsts : List Access) (replica_group : Sum String (List (List Int))) (concat_dim : Int) + (priority : Option Nat := none) (name : Option String := none) := do Trace.add_stmt $ .oper (.allToAll { dsts := dsts.map .abstract srcs := srcs.map .abstract replicaGroup := replicaGroupFromSum replica_group concatDim := some concat_dim + priority }) name return .none @@ -1060,11 +1076,13 @@ nki builtin.isa.collective_permute (src : Access) (dst : Access) (source_target_pairs: List (List Int)) + (priority : Option Nat := none) (name : Option String := none) := do Trace.add_stmt $ .oper (.collectivePermute { dsts := [.abstract dst] srcs := [.abstract src] sourceTargetPairs := some source_target_pairs + priority }) name return .none @@ -1075,6 +1093,7 @@ nki builtin.isa.collective_permute_implicit (replica_group : Sum String (List (List Int))) (channel_id : Int) (num_channels : Int := 1) + (priority : Option Nat := none) (name : Option String := none) := do Trace.add_stmt $ .oper (.collectivePermuteImplicit { dsts := [.abstract dst] @@ -1082,6 +1101,7 @@ nki builtin.isa.collective_permute_implicit replicaGroup := replicaGroupFromSum replica_group channel_id := some channel_id num_channels := some num_channels + priority }) name return .none @@ -1094,6 +1114,7 @@ nki builtin.isa.collective_permute_implicit_reduce (replica_group : Sum String (List (List Int))) (channel_id : Int) (num_channels : Int := 1) + (priority : Option Nat := none) (name : Option String := none) := do Trace.add_stmt $ .oper (.collectivePermuteImplicitReduce { dsts := [.abstract dst] @@ -1102,6 +1123,7 @@ nki builtin.isa.collective_permute_implicit_reduce replicaGroup := replicaGroupFromSum replica_group channel_id := some channel_id num_channels := some num_channels + priority }) name return .none diff --git a/KLR/Trace/Types.lean b/KLR/Trace/Types.lean index 8ddc5ff5..cd8f77eb 100644 --- a/KLR/Trace/Types.lean +++ b/KLR/Trace/Types.lean @@ -454,6 +454,7 @@ def addId : Trace Unit := do oobMode := .skip, dgeMode := 0, uniqueIndices := false + priority := .none engine := .unassigned }) none pos let lbl := (<- genLabel `init) diff --git a/interop/klr/NKI.asdl b/interop/klr/NKI.asdl index 5ecf8209..287c2875 100644 --- a/interop/klr/NKI.asdl +++ b/interop/klr/NKI.asdl @@ -217,7 +217,7 @@ AffineSelect = (TensorRef dst, TensorRef src, AffineSelectCmp fillMode, Reg fill DmaCopy = (TensorRef dst, TensorRef src, DgeComputeOp compute_op, DmaBounds dstBoundsCheck, DmaBounds srcBoundsCheck) -DmaTranspose = (TensorRef dst, TensorRef src, TransposeOps axes, Dtype? dtype, Nat dgeMode, DmaBounds oobMode) +DmaTranspose = (TensorRef dst, TensorRef src, TransposeOps axes, Dtype? dtype, Nat dgeMode, DmaBounds oobMode, Nat? priority) Transpose = (TensorRef dst, TensorRef src, Dtype? dtype, Engine engine) @@ -267,7 +267,7 @@ TensorTensor = (TensorRef dst, TensorRef src0, TensorRef src1, AluOp op, Dtype? NcMatMul = (TensorRef dst, TensorRef stationary, TensorRef moving, Bool isStationaryOneZero, Bool isMovingZero, Bool isTranspose, Nat* tilePosition, Nat* tileSize, MatmulPerfMode perfMode) -TensorScalarReduce = (TensorRef dst, TensorRef src, Operand operand0, AluOp op0, Bool reverse0, Dtype? dtype, AluOp? reduceOp, TensorRef reduceRes) +TensorScalarReduce = (TensorRef dst, TensorRef src, Operand operand0, AluOp op0, Bool reverse0, Dtype? dtype, AluOp? reduceOp, TensorRef reduceRes, AccumCmd reduceCmd, Immediate? reduceInit) TensorPartitionReduce = (TensorRef dst, AluOp op, TensorRef data, Dtype? dtype) @@ -275,7 +275,7 @@ NcActivate = (TensorRef dst, TensorRef src, AccumCmd accumulatorCmd, ActivationF NcAffineSelect = (TensorRef dst, DataPattern pred, TensorRef onTrueTile, Immediate onFalseValue, Dtype? dtype, AluOp cmpOp) -NcDmaCopy = (TensorRef dst, TensorRef src, DgeComputeOp compute_op, DmaBounds oobMode, Nat dgeMode, Bool uniqueIndices, Engine engine) +NcDmaCopy = (TensorRef dst, TensorRef src, DgeComputeOp compute_op, DmaBounds oobMode, Nat dgeMode, Bool uniqueIndices, Nat? priority, Engine engine) NcLocalGather = (TensorRef dst, TensorRef src, TensorRef index, Immediate numElemPerIdx, Immediate? numValidIndicies) @@ -304,7 +304,7 @@ ReplicaGroup = | literal(Int** groups) -CollectiveOp = (TensorRef* dsts, TensorRef* srcs, AluOp? op, ReplicaGroup replicaGroup, Int? concatDim, Int**? sourceTargetPairs, Int? channel_id, Int? num_channels) +CollectiveOp = (TensorRef* dsts, TensorRef* srcs, AluOp? op, ReplicaGroup replicaGroup, Int? concatDim, Int**? sourceTargetPairs, Int? channel_id, Int? num_channels, Nat? priority) RankId = (String dst) diff --git a/interop/klr/klir_ast.hpp b/interop/klr/klir_ast.hpp index 6cbe1cd7..3c71d008 100644 --- a/interop/klr/klir_ast.hpp +++ b/interop/klr/klir_ast.hpp @@ -541,6 +541,7 @@ struct DmaTranspose final { Option dtype; Nat dgeMode; Ptr oobMode; + Option priority; }; struct Transpose final { @@ -737,6 +738,8 @@ struct TensorScalarReduce final { Option dtype; Option reduceOp; Ptr reduceRes; + AccumCmd reduceCmd; + Option> reduceInit; }; struct TensorPartitionReduce final { @@ -774,6 +777,7 @@ struct NcDmaCopy final { Ptr oobMode; Nat dgeMode; Bool uniqueIndices; + Option priority; Engine engine; }; @@ -910,6 +914,7 @@ struct CollectiveOp final { Option>> sourceTargetPairs; Option channel_id; Option num_channels; + Option priority; }; struct RankId final { diff --git a/interop/klr/klir_pretty_print.cpp b/interop/klr/klir_pretty_print.cpp index a6fee5f2..57057e83 100644 --- a/interop/klr/klir_pretty_print.cpp +++ b/interop/klr/klir_pretty_print.cpp @@ -1372,6 +1372,13 @@ std::string to_string(DmaTranspose &DmaTransposeInstance) { result += ", "; result += "oobMode="; result += to_string(*(DmaTransposeInstance.oobMode.get())); + result += ", "; + result += "priority="; + if (DmaTransposeInstance.priority.has_value()) { + result += std::to_string(DmaTransposeInstance.priority.value()); + } else { + result += "None"; + } result += ")"; return result; }; @@ -2000,6 +2007,16 @@ std::string to_string(TensorScalarReduce &TensorScalarReduceInstance) { result += ", "; result += "reduceRes="; result += to_string(*(TensorScalarReduceInstance.reduceRes.get())); + result += ", "; + result += "reduceCmd="; + result += to_string(TensorScalarReduceInstance.reduceCmd); // mapped from enum + result += ", "; + result += "reduceInit="; + if (TensorScalarReduceInstance.reduceInit.has_value()) { + result += to_string(*(TensorScalarReduceInstance.reduceInit.value().get())); + } else { + result += "None"; + } result += ")"; return result; }; @@ -2127,6 +2144,13 @@ std::string to_string(NcDmaCopy &NcDmaCopyInstance) { result += "uniqueIndices="; result += std::to_string(NcDmaCopyInstance.uniqueIndices); result += ", "; + result += "priority="; + if (NcDmaCopyInstance.priority.has_value()) { + result += std::to_string(NcDmaCopyInstance.priority.value()); + } else { + result += "None"; + } + result += ", "; result += "engine="; result += to_string(NcDmaCopyInstance.engine); // mapped from enum result += ")"; @@ -2659,6 +2683,13 @@ std::string to_string(CollectiveOp &CollectiveOpInstance) { } else { result += "None"; } + result += ", "; + result += "priority="; + if (CollectiveOpInstance.priority.has_value()) { + result += std::to_string(CollectiveOpInstance.priority.value()); + } else { + result += "None"; + } result += ")"; return result; }; diff --git a/interop/klr/klir_serde.cpp b/interop/klr/klir_serde.cpp index 7c6bf26f..1705a4ef 100644 --- a/interop/klr/klir_serde.cpp +++ b/interop/klr/klir_serde.cpp @@ -1791,7 +1791,7 @@ bool DmaCopy_ser(FILE *out, const Ptr &value) { } bool DmaTranspose_ser(FILE *out, const Ptr &value) { - if (!serialize_tag(out, 154, 0, 6)) + if (!serialize_tag(out, 154, 0, 7)) return false; if (!TensorRef_ser(out, value->dst)) return false; @@ -1805,6 +1805,8 @@ bool DmaTranspose_ser(FILE *out, const Ptr &value) { return false; if (!DmaBounds_ser(out, value->oobMode)) return false; + if (!Option_Nat_ser(out, value->priority)) + return false; return true; } @@ -2179,7 +2181,7 @@ bool NcMatMul_ser(FILE *out, const Ptr &value) { } bool TensorScalarReduce_ser(FILE *out, const Ptr &value) { - if (!serialize_tag(out, 181, 0, 8)) + if (!serialize_tag(out, 181, 0, 10)) return false; if (!TensorRef_ser(out, value->dst)) return false; @@ -2197,6 +2199,10 @@ bool TensorScalarReduce_ser(FILE *out, const Ptr &value) { return false; if (!TensorRef_ser(out, value->reduceRes)) return false; + if (!AccumCmd_ser(out, value->reduceCmd)) + return false; + if (!Option_Immediate_ser(out, value->reduceInit)) + return false; return true; } @@ -2258,7 +2264,7 @@ bool NcAffineSelect_ser(FILE *out, const Ptr &value) { } bool NcDmaCopy_ser(FILE *out, const Ptr &value) { - if (!serialize_tag(out, 153, 0, 7)) + if (!serialize_tag(out, 153, 0, 8)) return false; if (!TensorRef_ser(out, value->dst)) return false; @@ -2272,6 +2278,8 @@ bool NcDmaCopy_ser(FILE *out, const Ptr &value) { return false; if (!Bool_ser(out, value->uniqueIndices)) return false; + if (!Option_Nat_ser(out, value->priority)) + return false; if (!Engine_ser(out, value->engine)) return false; return true; @@ -2527,7 +2535,7 @@ bool ReplicaGroup_ser(FILE *out, const Ptr &value) { } bool CollectiveOp_ser(FILE *out, const Ptr &value) { - if (!serialize_tag(out, 199, 0, 8)) + if (!serialize_tag(out, 199, 0, 9)) return false; if (!List_TensorRef_ser(out, value->dsts)) return false; @@ -2545,6 +2553,8 @@ bool CollectiveOp_ser(FILE *out, const Ptr &value) { return false; if (!Option_Int_ser(out, value->num_channels)) return false; + if (!Option_Nat_ser(out, value->priority)) + return false; return true; } @@ -5511,9 +5521,9 @@ Ptr DmaTranspose_des(FILE *in) { msg << "Could not find tag, expecting DmaTranspose:154,0"; throw std::runtime_error(msg.str()); } - if (t != 154 || c != 0 || l != 6) { + if (t != 154 || c != 0 || l != 7) { std::ostringstream msg; - msg << "Expecting DmaTranspose:(154,0,6)"; + msg << "Expecting DmaTranspose:(154,0,7)"; msg << " got:(" << (int)t << "," << (int)c << "," << (int)l << ")"; throw std::runtime_error(msg.str()); } @@ -5524,6 +5534,7 @@ Ptr DmaTranspose_des(FILE *in) { x->dtype = Option_Dtype_des(in); x->dgeMode = Nat_des(in); x->oobMode = DmaBounds_des(in); + x->priority = Option_Nat_des(in); return x; } @@ -6055,9 +6066,9 @@ Ptr TensorScalarReduce_des(FILE *in) { msg << "Could not find tag, expecting TensorScalarReduce:181,0"; throw std::runtime_error(msg.str()); } - if (t != 181 || c != 0 || l != 8) { + if (t != 181 || c != 0 || l != 10) { std::ostringstream msg; - msg << "Expecting TensorScalarReduce:(181,0,8)"; + msg << "Expecting TensorScalarReduce:(181,0,10)"; msg << " got:(" << (int)t << "," << (int)c << "," << (int)l << ")"; throw std::runtime_error(msg.str()); } @@ -6070,6 +6081,8 @@ Ptr TensorScalarReduce_des(FILE *in) { x->dtype = Option_Dtype_des(in); x->reduceOp = Option_AluOp_des(in); x->reduceRes = TensorRef_des(in); + x->reduceCmd = AccumCmd_des(in); + x->reduceInit = Option_Immediate_des(in); return x; } @@ -6150,9 +6163,9 @@ Ptr NcDmaCopy_des(FILE *in) { msg << "Could not find tag, expecting NcDmaCopy:153,0"; throw std::runtime_error(msg.str()); } - if (t != 153 || c != 0 || l != 7) { + if (t != 153 || c != 0 || l != 8) { std::ostringstream msg; - msg << "Expecting NcDmaCopy:(153,0,7)"; + msg << "Expecting NcDmaCopy:(153,0,8)"; msg << " got:(" << (int)t << "," << (int)c << "," << (int)l << ")"; throw std::runtime_error(msg.str()); } @@ -6163,6 +6176,7 @@ Ptr NcDmaCopy_des(FILE *in) { x->oobMode = DmaBounds_des(in); x->dgeMode = Nat_des(in); x->uniqueIndices = Bool_des(in); + x->priority = Option_Nat_des(in); x->engine = Engine_des(in); return x; } @@ -6464,9 +6478,9 @@ Ptr CollectiveOp_des(FILE *in) { msg << "Could not find tag, expecting CollectiveOp:199,0"; throw std::runtime_error(msg.str()); } - if (t != 199 || c != 0 || l != 8) { + if (t != 199 || c != 0 || l != 9) { std::ostringstream msg; - msg << "Expecting CollectiveOp:(199,0,8)"; + msg << "Expecting CollectiveOp:(199,0,9)"; msg << " got:(" << (int)t << "," << (int)c << "," << (int)l << ")"; throw std::runtime_error(msg.str()); } @@ -6479,6 +6493,7 @@ Ptr CollectiveOp_des(FILE *in) { x->sourceTargetPairs = Option_List_List_Int_des(in); x->channel_id = Option_Int_des(in); x->num_channels = Option_Int_des(in); + x->priority = Option_Nat_des(in); return x; }