Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions KLR/Core/Operators.lean
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,7 @@ structure NcDmaCopy where
oobMode : DmaBounds
dgeMode : Nat
uniqueIndices : Bool
priority : Option Nat
engine : Engine
deriving BEq, FromCBOR, FromJson, FromSexp, Repr, ToCBOR, ToJson, ToSexp

Expand All @@ -415,6 +416,7 @@ structure DmaTranspose where
dtype : Option Dtype
dgeMode : Nat
oobMode : DmaBounds
priority : Option Nat
deriving BEq, FromCBOR, FromJson, FromSexp, Repr, ToCBOR, ToJson, ToSexp

instance : MapTensorRefs DmaTranspose where
Expand Down Expand Up @@ -824,6 +826,8 @@ structure TensorScalarReduce where
dtype : Option Dtype
reduceOp : Option AluOp
reduceRes : TensorRef
reduceCmd : AccumCmd
reduceInit : Option Immediate
deriving BEq, FromCBOR, FromJson, FromSexp, Repr, ToCBOR, ToJson, ToSexp

instance : MapTensorRefs TensorScalarReduce where
Expand Down Expand Up @@ -1037,6 +1041,7 @@ structure CollectiveOp where
sourceTargetPairs : Option (List (List Int)) := none
channel_id : Option Int := none
num_channels : Option Int := none
priority : Option Nat := none
deriving BEq, FromCBOR, FromJson, FromSexp, Repr, ToCBOR, ToJson, ToSexp

instance : MapTensorRefs CollectiveOp where
Expand Down
22 changes: 22 additions & 0 deletions KLR/Trace/ISA.lean
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,8 @@ nki builtin.isa.tensor_scalar_reduce
(reduce_op : AluOp)
(reduce_res : Access)
(reverse0 : Bool := false)
(reduce_cmd : AccumCmd := .Zero)
(reduce_init : Option Immediate := none)
(mask : Option Immediate := none)
(name : Option String := none) := do
if mask.isSome then throw maskNotSupported
Expand All @@ -469,6 +471,8 @@ nki builtin.isa.tensor_scalar_reduce
reduceRes := .abstract reduce_res
reverse0 := reverse0
dtype := dst.tensor.dtype
reduceCmd := reduce_cmd
reduceInit := reduce_init
}) name
return .none

Expand Down Expand Up @@ -745,6 +749,7 @@ nki builtin.isa.dma_copy
(dge_mode : Nat := 0)
(unique_indices : Bool := false)
(engine : Engine := .unassigned)
(priority : Option Nat := .none)
(name : Option String := none) := do
if mask.isSome then throw maskNotSupported
let op : DgeComputeOp := <- match dst_rmw_op with
Expand All @@ -763,6 +768,7 @@ nki builtin.isa.dma_copy
| _ => .skip,
dgeMode := dge_mode,
uniqueIndices := unique_indices
priority
engine
}) name
return .none
Expand All @@ -775,6 +781,7 @@ nki builtin.isa.dma_transpose
(mask : Option Immediate := none)
(dge_mode : Nat := 0)
(oob_mode : Nat := 0)
(priority : Option Nat := .none)
(name : Option String := none) := do
if mask.isSome then throw maskNotSupported
if oob_mode > 1 then throw "unsupported oob mode"
Expand All @@ -792,6 +799,7 @@ nki builtin.isa.dma_transpose
| 0 => .error
| 1 => .skip
| _ => .error,
priority
}) name
return .none

Expand Down Expand Up @@ -999,12 +1007,14 @@ nki builtin.isa.all_reduce
(srcs : List Access)
(dsts : List Access)
(replica_group: Sum String (List (List Int)))
(priority : Option Nat := none)
(name : Option String := none) := do
Trace.add_stmt $ .oper (.allReduce {
dsts := dsts.map .abstract
srcs := srcs.map .abstract
op := some op
replicaGroup := replicaGroupFromSum replica_group
priority
}) name
return .none

Expand All @@ -1014,12 +1024,14 @@ nki builtin.isa.all_gather
(dsts : List Access)
(replica_group : Sum String (List (List Int)))
(concat_dim : Int)
(priority : Option Nat := none)
(name : Option String := none) := do
Trace.add_stmt $ .oper (.allGather {
dsts := dsts.map .abstract
srcs := srcs.map .abstract
replicaGroup := replicaGroupFromSum replica_group
concatDim := some concat_dim
priority
}) name
return .none

Expand All @@ -1030,13 +1042,15 @@ nki builtin.isa.reduce_scatter
(dsts : List Access)
(replica_group : Sum String (List (List Int)))
(concat_dim : Int)
(priority : Option Nat := none)
(name : Option String := none) := do
Trace.add_stmt $ .oper (.reduceScatter {
dsts := dsts.map .abstract
srcs := srcs.map .abstract
op := some op
replicaGroup := replicaGroupFromSum replica_group
concatDim := some concat_dim
priority
}) name
return .none

Expand All @@ -1046,12 +1060,14 @@ nki builtin.isa.all_to_all
(dsts : List Access)
(replica_group : Sum String (List (List Int)))
(concat_dim : Int)
(priority : Option Nat := none)
(name : Option String := none) := do
Trace.add_stmt $ .oper (.allToAll {
dsts := dsts.map .abstract
srcs := srcs.map .abstract
replicaGroup := replicaGroupFromSum replica_group
concatDim := some concat_dim
priority
}) name
return .none

Expand All @@ -1060,11 +1076,13 @@ nki builtin.isa.collective_permute
(src : Access)
(dst : Access)
(source_target_pairs: List (List Int))
(priority : Option Nat := none)
(name : Option String := none) := do
Trace.add_stmt $ .oper (.collectivePermute {
dsts := [.abstract dst]
srcs := [.abstract src]
sourceTargetPairs := some source_target_pairs
priority
}) name
return .none

Expand All @@ -1075,13 +1093,15 @@ nki builtin.isa.collective_permute_implicit
(replica_group : Sum String (List (List Int)))
(channel_id : Int)
(num_channels : Int := 1)
(priority : Option Nat := none)
(name : Option String := none) := do
Trace.add_stmt $ .oper (.collectivePermuteImplicit {
dsts := [.abstract dst]
srcs := [.abstract src]
replicaGroup := replicaGroupFromSum replica_group
channel_id := some channel_id
num_channels := some num_channels
priority
}) name
return .none

Expand All @@ -1094,6 +1114,7 @@ nki builtin.isa.collective_permute_implicit_reduce
(replica_group : Sum String (List (List Int)))
(channel_id : Int)
(num_channels : Int := 1)
(priority : Option Nat := none)
(name : Option String := none) := do
Trace.add_stmt $ .oper (.collectivePermuteImplicitReduce {
dsts := [.abstract dst]
Expand All @@ -1102,6 +1123,7 @@ nki builtin.isa.collective_permute_implicit_reduce
replicaGroup := replicaGroupFromSum replica_group
channel_id := some channel_id
num_channels := some num_channels
priority
}) name
return .none

Expand Down
1 change: 1 addition & 0 deletions KLR/Trace/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,7 @@ def addId : Trace Unit := do
oobMode := .skip,
dgeMode := 0,
uniqueIndices := false
priority := .none
engine := .unassigned
}) none pos
let lbl := (<- genLabel `init)
Expand Down
8 changes: 4 additions & 4 deletions interop/klr/NKI.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ AffineSelect = (TensorRef dst, TensorRef src, AffineSelectCmp fillMode, Reg fill

DmaCopy = (TensorRef dst, TensorRef src, DgeComputeOp compute_op, DmaBounds dstBoundsCheck, DmaBounds srcBoundsCheck)

DmaTranspose = (TensorRef dst, TensorRef src, TransposeOps axes, Dtype? dtype, Nat dgeMode, DmaBounds oobMode)
-- DMA transpose of src into dst along the given axis permutation.
-- `Nat? priority` mirrors the optional `priority : Option Nat` field on the
-- Lean structure KLR.Core.DmaTranspose and on klir_ast.hpp's DmaTranspose.
DmaTranspose = (TensorRef dst, TensorRef src, TransposeOps axes, Dtype? dtype, Nat dgeMode, DmaBounds oobMode, Nat? priority)

Transpose = (TensorRef dst, TensorRef src, Dtype? dtype, Engine engine)

Expand Down Expand Up @@ -267,15 +267,15 @@ TensorTensor = (TensorRef dst, TensorRef src0, TensorRef src1, AluOp op, Dtype?

NcMatMul = (TensorRef dst, TensorRef stationary, TensorRef moving, Bool isStationaryOneZero, Bool isMovingZero, Bool isTranspose, Nat* tilePosition, Nat* tileSize, MatmulPerfMode perfMode)

TensorScalarReduce = (TensorRef dst, TensorRef src, Operand operand0, AluOp op0, Bool reverse0, Dtype? dtype, AluOp? reduceOp, TensorRef reduceRes)
-- Tensor-scalar operation with a fused reduction into reduceRes.
-- `AccumCmd reduceCmd` and `Immediate? reduceInit` mirror the Lean structure
-- KLR.Core.TensorScalarReduce; the tracer (builtin.isa.tensor_scalar_reduce)
-- defaults them to `.Zero` and `none` respectively.
TensorScalarReduce = (TensorRef dst, TensorRef src, Operand operand0, AluOp op0, Bool reverse0, Dtype? dtype, AluOp? reduceOp, TensorRef reduceRes, AccumCmd reduceCmd, Immediate? reduceInit)

TensorPartitionReduce = (TensorRef dst, AluOp op, TensorRef data, Dtype? dtype)

NcActivate = (TensorRef dst, TensorRef src, AccumCmd accumulatorCmd, ActivationFunc activationFunc, Operand scale, TensorRef? bias, AluOp? reduceOp, TensorRef? reduceRes, Dtype? dtype)

NcAffineSelect = (TensorRef dst, DataPattern pred, TensorRef onTrueTile, Immediate onFalseValue, Dtype? dtype, AluOp cmpOp)

NcDmaCopy = (TensorRef dst, TensorRef src, DgeComputeOp compute_op, DmaBounds oobMode, Nat dgeMode, Bool uniqueIndices, Engine engine)
-- DMA copy with descriptor-generation options. `Nat? priority` is placed
-- before `engine`, matching the field order of the Lean structure
-- KLR.Core.NcDmaCopy and of klir_ast.hpp's NcDmaCopy struct.
NcDmaCopy = (TensorRef dst, TensorRef src, DgeComputeOp compute_op, DmaBounds oobMode, Nat dgeMode, Bool uniqueIndices, Nat? priority, Engine engine)

-- NOTE(review): "numValidIndicies" misspells "numValidIndices". Left as-is
-- here because the identifier is shared with generated/serialized code —
-- confirm all consumers before renaming.
NcLocalGather = (TensorRef dst, TensorRef src, TensorRef index, Immediate numElemPerIdx, Immediate? numValidIndicies)

Expand Down Expand Up @@ -304,7 +304,7 @@ ReplicaGroup =
| literal(Int** groups)


CollectiveOp = (TensorRef* dsts, TensorRef* srcs, AluOp? op, ReplicaGroup replicaGroup, Int? concatDim, Int**? sourceTargetPairs, Int? channel_id, Int? num_channels)
-- Shared payload for the collective operations (allReduce, allGather,
-- reduceScatter, allToAll, collectivePermute, collectivePermuteImplicit,
-- collectivePermuteImplicitReduce). `Nat? priority` mirrors
-- KLR.Core.CollectiveOp's `priority : Option Nat := none`; every tracer
-- builtin that constructs a CollectiveOp defaults it to none.
CollectiveOp = (TensorRef* dsts, TensorRef* srcs, AluOp? op, ReplicaGroup replicaGroup, Int? concatDim, Int**? sourceTargetPairs, Int? channel_id, Int? num_channels, Nat? priority)

RankId = (String dst)

Expand Down
5 changes: 5 additions & 0 deletions interop/klr/klir_ast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,7 @@ struct DmaTranspose final {
Option<Dtype> dtype;
Nat dgeMode;
Ptr<DmaBounds> oobMode;
Option<Nat> priority;
};

struct Transpose final {
Expand Down Expand Up @@ -737,6 +738,8 @@ struct TensorScalarReduce final {
Option<Dtype> dtype;
Option<AluOp> reduceOp;
Ptr<TensorRef> reduceRes;
AccumCmd reduceCmd;
Option<Ptr<Immediate>> reduceInit;
};

struct TensorPartitionReduce final {
Expand Down Expand Up @@ -774,6 +777,7 @@ struct NcDmaCopy final {
Ptr<DmaBounds> oobMode;
Nat dgeMode;
Bool uniqueIndices;
Option<Nat> priority;
Engine engine;
};

Expand Down Expand Up @@ -910,6 +914,7 @@ struct CollectiveOp final {
Option<List<List<Int>>> sourceTargetPairs;
Option<Int> channel_id;
Option<Int> num_channels;
Option<Nat> priority;
};

struct RankId final {
Expand Down
31 changes: 31 additions & 0 deletions interop/klr/klir_pretty_print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1372,6 +1372,13 @@ std::string to_string(DmaTranspose &DmaTransposeInstance) {
result += ", ";
result += "oobMode=";
result += to_string(*(DmaTransposeInstance.oobMode.get()));
result += ", ";
result += "priority=";
if (DmaTransposeInstance.priority.has_value()) {
result += std::to_string(DmaTransposeInstance.priority.value());
} else {
result += "None";
}
result += ")";
return result;
};
Expand Down Expand Up @@ -2000,6 +2007,16 @@ std::string to_string(TensorScalarReduce &TensorScalarReduceInstance) {
result += ", ";
result += "reduceRes=";
result += to_string(*(TensorScalarReduceInstance.reduceRes.get()));
result += ", ";
result += "reduceCmd=";
result += to_string(TensorScalarReduceInstance.reduceCmd); // mapped from enum
result += ", ";
result += "reduceInit=";
if (TensorScalarReduceInstance.reduceInit.has_value()) {
result += to_string(*(TensorScalarReduceInstance.reduceInit.value().get()));
} else {
result += "None";
}
result += ")";
return result;
};
Expand Down Expand Up @@ -2127,6 +2144,13 @@ std::string to_string(NcDmaCopy &NcDmaCopyInstance) {
result += "uniqueIndices=";
result += std::to_string(NcDmaCopyInstance.uniqueIndices);
result += ", ";
result += "priority=";
if (NcDmaCopyInstance.priority.has_value()) {
result += std::to_string(NcDmaCopyInstance.priority.value());
} else {
result += "None";
}
result += ", ";
result += "engine=";
result += to_string(NcDmaCopyInstance.engine); // mapped from enum
result += ")";
Expand Down Expand Up @@ -2659,6 +2683,13 @@ std::string to_string(CollectiveOp &CollectiveOpInstance) {
} else {
result += "None";
}
result += ", ";
result += "priority=";
if (CollectiveOpInstance.priority.has_value()) {
result += std::to_string(CollectiveOpInstance.priority.value());
} else {
result += "None";
}
result += ")";
return result;
};
Expand Down
Loading
Loading