Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions docs/PTO_IR_manual.md
Original file line number Diff line number Diff line change
Expand Up @@ -4517,6 +4517,44 @@ pto.treshape ins(%src : !pto.tile_buf<...>) outs(%dst : !pto.tile_buf<...>)

---

##### `pto.tassemble` - Insert Sub-Tile Window

**Summary:** Inserts a source tile into a destination tile at a given row/col offset.

**Semantics:**

```
dst[i + indexRow, j + indexCol] = src[i, j]
```

**Arguments:**

| Name | Type | Description |
|------|------|-------------|
| `src` | `pto.tile_buf` | Source tile |
| `indexRow` | `Index` | Destination row offset |
| `indexCol` | `Index` | Destination column offset |
| `dst` | `pto.tile_buf` | Destination tile |

**Results:** None. Writes into `dst` via DPS pattern.

**Constraints & Verification:**

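- `src` and `dst` must be rank-2 PTO shaped types
- Element types must match, or `src` is `f32` and `dst` is `f16`/`bf16`
- `indexRow`/`indexCol` must be of `index` type; offsets that are compile-time constants must be non-negative
- When an offset is constant and the corresponding `src`/`dst` dimensions are static, `offset + src extent` must not exceed the `dst` extent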

**Hardware Mapping:**

- Lowers to **`TINSERT(dst, src, indexRow, indexCol)`**
- Runs on the target's data-movement pipe: `MTE3` when the module's `pto.device-spec` attribute starts with `Ascend950` or `Ascend910_95` (A5, UB->L1 path), `MTE1` otherwise

**Basic Example:**

```mlir
pto.tassemble ins(%src, %row, %col : !pto.tile_buf<...>, index, index) outs(%dst : !pto.tile_buf<...>)
```

---

##### `pto.textract` - Extract Sub-Tile Window

**Summary:** Extracts a sub-tile window from a source tile into a destination tile.
Expand Down
44 changes: 44 additions & 0 deletions include/PTO/IR/PTOOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2027,6 +2027,50 @@ def TExtractOp : PTO_TOp<"textract", [
}];
}

// pto.tassemble: DPS-style op that inserts the `src` tile into the `dst` tile
// at the (indexRow, indexCol) offset. It produces no results; `dst` is the DPS
// init operand.
def TAssembleOp : PTO_TOp<"tassemble", [
    PTO_DpsInitOpInterface,
    OpPipeInterface,
    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
  ]> {
  let summary = "Insert src sub-tile into dst at (indexRow, indexCol) (tilebuf, DPS)";

  // Operand order: source tile, row offset, column offset, destination tile.
  let arguments = (ins
    PTODpsType:$src,
    Index:$indexRow,
    Index:$indexCol,
    PTODpsType:$dst
  );

  let results = (outs);

  // The verifier is implemented as TAssembleOp::verify() in C++.
  let hasVerifier = 1;

  let assemblyFormat = [{
    `ins` `(` $src `,` $indexRow `,` $indexCol `:` qualified(type($src)) `,` type($indexRow) `,` type($indexCol) `)`
    `outs` `(` $dst `:` qualified(type($dst) ) `)`
    attr-dict
  }];

  let extraClassDeclaration = [{
    // The destination tile is the only DPS init operand.
    ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); }

    // Selects the DMA pipe TINSERT runs on, based on the enclosing module's
    // "pto.device-spec" string attribute:
    // - A5 (spec starting with Ascend950 / Ascend910_95): MTE3, the UB->L1
    //   path used by pto-isa custom kernels.
    // - Anything else, including a missing module or attribute: MTE1, kept
    //   for compatibility with existing data-movement sync.
    ::mlir::pto::PIPE getPipe() {
      auto enclosingModule = getOperation()->getParentOfType<::mlir::ModuleOp>();
      if (!enclosingModule)
        return ::mlir::pto::PIPE::PIPE_MTE1;

      auto deviceSpec =
          enclosingModule->getAttrOfType<::mlir::StringAttr>("pto.device-spec");
      if (!deviceSpec)
        return ::mlir::pto::PIPE::PIPE_MTE1;

      auto specName = deviceSpec.getValue();
      bool isA5 = specName.starts_with("Ascend950") ||
                  specName.starts_with("Ascend910_95");
      return isA5 ? ::mlir::pto::PIPE::PIPE_MTE3 : ::mlir::pto::PIPE::PIPE_MTE1;
    }
  }];
}

def TFillPadOp : PTO_TOp<"tfillpad", [
PTO_DpsInitOpInterface,
OpPipeInterface,
Expand Down
76 changes: 76 additions & 0 deletions lib/PTO/IR/PTO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1528,6 +1528,75 @@ mlir::LogicalResult mlir::pto::TExtractOp::verify() {
return mlir::success();
}
//===----------------------------------------------------------------------===//
// TAssembleOp_DPS verifier
//===----------------------------------------------------------------------===//

/// Verifies a `pto.tassemble` op.
///
/// Checks performed, in order:
///  - `src` and `dst` are PTO shaped-like types of rank 2;
///  - element types are equal, or `src` is f32 and `dst` is f16/bf16;
///  - `indexRow` / `indexCol` have `index` type;
///  - offsets defined by an `arith.constant` are non-negative;
///  - when an offset is constant and the matching src/dst extents are static,
///    the inserted window stays inside `dst`.
///
/// Offsets that are not compile-time constants and dynamic extents are not
/// range-checked here.
mlir::LogicalResult mlir::pto::TAssembleOp::verify() {
  Type srcTy = getSrc().getType();
  Type dstTy = getDst().getType();
  if (!isPTOShapedLike(srcTy) || !isPTOShapedLike(dstTy))
    return emitOpError("expects src/dst to be PTO shaped-like types");

  auto srcShape = getShapeVec(srcTy);
  auto dstShape = getShapeVec(dstTy);
  if (srcShape.size() != 2 || dstShape.size() != 2)
    return emitOpError("expects rank-2 shaped types for src/dst");

  Type srcElemTy = getElemTy(srcTy);
  Type dstElemTy = getElemTy(dstTy);
  bool sameElemTy = srcElemTy == dstElemTy;
  bool castElemTy =
      srcElemTy.isF32() && (dstElemTy.isF16() || dstElemTy.isBF16());
  if (!sameElemTy && !castElemTy)
    return emitOpError(
        "expects src/dst element types to match, or src=f32 with dst=f16/bf16");

  if (!getIndexRow().getType().isIndex() || !getIndexCol().getType().isIndex())
    return emitOpError("expects indexRow/indexCol to be index type");

  // Reads the integer payload of an `arith.constant` defining `v`, if any.
  // arith::ConstantIndexOp and arith::ConstantIntOp are both views over
  // arith::ConstantOp carrying an IntegerAttr, so this one match covers them.
  auto readConstIndex = [](Value v, int64_t &out) -> bool {
    auto constOp = v.getDefiningOp<mlir::arith::ConstantOp>();
    if (!constOp)
      return false;
    auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(constOp.getValue());
    if (!intAttr)
      return false;
    out = intAttr.getInt();
    return true;
  };

  int64_t r0 = 0;
  int64_t c0 = 0;
  bool rowConst = readConstIndex(getIndexRow(), r0);
  bool colConst = readConstIndex(getIndexCol(), c0);
  if (rowConst && r0 < 0)
    return emitOpError("indexRow must be non-negative");
  if (colConst && c0 < 0)
    return emitOpError("indexCol must be non-negative");

  // True when `offset + srcExtent > dstExtent`, written as a subtraction so a
  // huge constant offset cannot overflow int64_t (signed overflow is UB and
  // could let an out-of-range insert pass). Callers guarantee offset >= 0 and
  // that both extents are static, so the subtraction itself cannot overflow.
  auto windowExceeds = [](int64_t offset, int64_t srcExtent,
                          int64_t dstExtent) -> bool {
    return offset > dstExtent - srcExtent;
  };
  auto isStatic = [](int64_t extent) -> bool {
    return extent != mlir::ShapedType::kDynamic;
  };

  int64_t srcRows = srcShape[0];
  int64_t srcCols = srcShape[1];
  int64_t dstRows = dstShape[0];
  int64_t dstCols = dstShape[1];
  if (rowConst && isStatic(srcRows) && isStatic(dstRows) &&
      windowExceeds(r0, srcRows, dstRows))
    return emitOpError("indexRow + src rows exceeds dst rows");
  if (colConst && isStatic(srcCols) && isStatic(dstCols) &&
      windowExceeds(c0, srcCols, dstCols))
    return emitOpError("indexCol + src cols exceeds dst cols");

  return mlir::success();
}
//===----------------------------------------------------------------------===//
// TFillPadOp_DPS verifier
//===----------------------------------------------------------------------===//

Expand Down Expand Up @@ -4282,6 +4351,13 @@ void TExtractOp::getEffects(
PTO_ADD_WRITE(getDstMutable());
}

// TASSEMBLE: Read(src) -> Write(dst)
// Memory-effect hook for pto.tassemble: the source tile operand is registered
// through PTO_ADD_READ and the destination tile operand through PTO_ADD_WRITE.
// NOTE(review): both macros are defined elsewhere in this file; presumably
// they append a Read / Write effect instance to `effects` -- confirm there.
// The two index operands are not registered with either macro.
void TAssembleOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> &effects) {
PTO_ADD_READ(getSrcMutable());
PTO_ADD_WRITE(getDstMutable());
}

PTO_DEFINE_UNARY_EFFECTS(TFillPadOp, getSrcMutable(), getDstMutable())
PTO_DEFINE_UNARY_EFFECTS(TFillPadExpandOp, getSrcMutable(), getDstMutable())

Expand Down
27 changes: 26 additions & 1 deletion lib/PTO/Transforms/PTOToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4743,6 +4743,31 @@ struct PTOExtractToEmitC : public OpConversionPattern<pto::TExtractOp> {
}
};
//===----------------------------------------------------------------------===//
// pto.tassemble lowering -> TINSERT(dst, src, indexRow, indexCol)
//===----------------------------------------------------------------------===//

/// Conversion pattern that rewrites `pto.tassemble` into an opaque EmitC call
/// `TINSERT(dst, src, indexRow, indexCol)` with no results, then erases the
/// original op.
struct PTOAssembleToEmitC : public OpConversionPattern<pto::TAssembleOp> {
  using OpConversionPattern<pto::TAssembleOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(pto::TAssembleOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    // Pass each converted operand through peelUnrealized() before it is used
    // as a call argument.
    Value srcTile = peelUnrealized(adaptor.getSrc());
    Value dstTile = peelUnrealized(adaptor.getDst());
    Value rowOffset = peelUnrealized(adaptor.getIndexRow());
    Value colOffset = peelUnrealized(adaptor.getIndexCol());

    // The call takes the destination first, unlike the op's own operand
    // order (src, indexRow, indexCol, dst).
    rewriter.create<emitc::CallOpaqueOp>(
        op.getLoc(), TypeRange{}, "TINSERT",
        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
        /*operands=*/ValueRange{dstTile, srcTile, rowOffset, colOffset});

    // The op has no results, so erasing it is sufficient.
    rewriter.eraseOp(op);
    return success();
  }
};
//===----------------------------------------------------------------------===//
// pto.tfillpad lowering -> TFILLPAD(dst, src)
//===----------------------------------------------------------------------===//

Expand Down Expand Up @@ -7298,7 +7323,7 @@ static void populatePTOToEmitCPatterns(RewritePatternSet &patterns,
patterns.add<PTOExpandsToEmitC>(typeConverter, ctx);
patterns.add<PTOOrToEmitC>(typeConverter, ctx);
patterns.add<PTOPartAddToEmitC>(typeConverter, ctx);
patterns.add<PTOExtractToEmitC>(typeConverter, ctx);
patterns.add<PTOExtractToEmitC, PTOAssembleToEmitC>(typeConverter, ctx);
patterns.add<PTOFillPadToEmitC, PTOFillPadExpandToEmitC>(typeConverter, ctx);
patterns.add<PTOGatherToEmitC>(typeConverter, ctx);
patterns.add<PTOGatherbToEmitC>(typeConverter, ctx);
Expand Down
4 changes: 4 additions & 0 deletions test/samples/Assemble/assemble.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Regex patterns that must appear in emitted C++ for assemble.py
TLOAD\(
TINSERT\(
TSTORE\(
148 changes: 148 additions & 0 deletions test/samples/Assemble/assemble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
from mlir.ir import Context, Location, Module, InsertionPoint
from mlir.dialects import func, arith, pto
from mlir.ir import F16Type, F32Type, IndexType


def build():
    """Construct the MLIR module for the ``pto.tassemble`` lowering sample.

    The module contains one function, ``assemble_kernel``, taking four pointer
    arguments (f32, f32, f16, f32). Its body:

    1. loads two 32x32 f32 tiles and multiplies them into an f32 accumulator;
    2. inserts that accumulator into an f16 tile with ``pto.tassemble`` at
       offset (0, 0) -- the operation under test;
    3. loads a 32x32 f16 tile, multiplies the assembled tile with it, and
       stores the resulting f32 accumulator.

    Returns:
        The ``Module``, after ``verify()`` has been called on its operation.
    """
    with Context() as ctx:
        pto.register_dialect(ctx, load=True)

        with Location.unknown(ctx):
            module = Module.create()

            # Element, index and pointer types.
            f16 = F16Type.get(ctx)
            f32 = F32Type.get(ctx)
            index_ty = IndexType.get(ctx)
            ptr_f32 = pto.PtrType.get(f32, ctx)
            ptr_f16 = pto.PtrType.get(f16, ctx)

            # Rank-2 tensor views and their 32x32 partition views.
            view_f32_ty = pto.TensorViewType.get(2, f32, ctx)
            view_f16_ty = pto.TensorViewType.get(2, f16, ctx)
            part_f32_ty = pto.PartitionTensorViewType.get([32, 32], f32, ctx)
            part_f16_ty = pto.PartitionTensorViewType.get([32, 32], f16, ctx)

            # Address spaces and the shared pad value.
            mat = pto.AddressSpaceAttr.get(pto.AddressSpace.MAT, ctx)
            left = pto.AddressSpaceAttr.get(pto.AddressSpace.LEFT, ctx)
            right = pto.AddressSpaceAttr.get(pto.AddressSpace.RIGHT, ctx)
            acc = pto.AddressSpaceAttr.get(pto.AddressSpace.ACC, ctx)
            pad = pto.PadValueAttr.get(pto.PadValue.Null, ctx)

            def tile_config(b_layout, s_layout, fractal):
                # Tile-buffer config for the given layouts and fractal size.
                return pto.TileBufConfigAttr.get(
                    pto.BLayoutAttr.get(b_layout, ctx),
                    pto.SLayoutAttr.get(s_layout, ctx),
                    fractal,
                    pad,
                    ctx,
                )

            # The configs do not depend on the element type, so MAT and LEFT
            # tiles (f32 and f16) share one config and RIGHT tiles share another.
            cfg_mat_left = tile_config(
                pto.BLayout.ColMajor, pto.SLayout.RowMajor, pto.TileConfig.fractalABSize
            )
            cfg_right = tile_config(
                pto.BLayout.RowMajor, pto.SLayout.ColMajor, pto.TileConfig.fractalABSize
            )
            cfg_acc = tile_config(pto.BLayout.ColMajor, pto.SLayout.RowMajor, 1024)

            def tile_ty(elem_ty, space, cfg):
                # 32x32 tile buffer type; both shape arguments are [32, 32].
                return pto.TileBufType.get([32, 32], elem_ty, space, [32, 32], cfg, ctx)

            a_mat_ty = tile_ty(f32, mat, cfg_mat_left)
            b_mat_ty = tile_ty(f32, mat, cfg_mat_left)
            a_left_ty = tile_ty(f32, left, cfg_mat_left)
            b_right_ty = tile_ty(f32, right, cfg_right)
            src_acc_ty = tile_ty(f32, acc, cfg_acc)
            dst_mat_ty = tile_ty(f16, mat, cfg_mat_left)
            out_left_ty = tile_ty(f16, left, cfg_mat_left)
            i_mat_ty = tile_ty(f16, mat, cfg_mat_left)
            i_right_ty = tile_ty(f16, right, cfg_right)
            out_acc_ty = tile_ty(f32, acc, cfg_acc)

            kernel_ty = func.FunctionType.get([ptr_f32, ptr_f32, ptr_f16, ptr_f32], [])
            with InsertionPoint(module.body):
                kernel = func.FuncOp("assemble_kernel", kernel_ty)
                entry_block = kernel.add_entry_block()

            with InsertionPoint(entry_block):
                # Index constants, emitted in the order 0, 1, 32.
                c0, c1, c32 = [
                    arith.ConstantOp(index_ty, value).result for value in (0, 1, 32)
                ]

                ptr_a, ptr_b, ptr_i, ptr_out = entry_block.arguments

                def tensor_view(view_ty, ptr):
                    # 32x32 view over `ptr` with strides (32, 1).
                    return pto.MakeTensorViewOp(view_ty, ptr, [c32, c32], [c32, c1]).result

                def whole_partition(part_ty, view):
                    # Partition covering the full 32x32 view from offset (0, 0).
                    return pto.PartitionViewOp(
                        part_ty, view, offsets=[c0, c0], sizes=[c32, c32]
                    ).result

                def alloc(ty):
                    return pto.AllocTileOp(ty).result

                view_a = tensor_view(view_f32_ty, ptr_a)
                view_b = tensor_view(view_f32_ty, ptr_b)
                view_i = tensor_view(view_f16_ty, ptr_i)
                view_out = tensor_view(view_f32_ty, ptr_out)

                part_a = whole_partition(part_f32_ty, view_a)
                part_b = whole_partition(part_f32_ty, view_b)
                part_i = whole_partition(part_f16_ty, view_i)
                part_out = whole_partition(part_f32_ty, view_out)

                a_mat = alloc(a_mat_ty)
                b_mat = alloc(b_mat_ty)
                a_left = alloc(a_left_ty)
                b_right = alloc(b_right_ty)
                src_acc = alloc(src_acc_ty)
                dst_mat = alloc(dst_mat_ty)
                out_left = alloc(out_left_ty)
                i_mat = alloc(i_mat_ty)
                i_right = alloc(i_right_ty)
                out_acc = alloc(out_acc_ty)

                # First matmul: produces the f32 accumulator to be inserted.
                pto.TLoadOp(None, part_a, a_mat)
                pto.TLoadOp(None, part_b, b_mat)
                pto.TMovOp(None, a_mat, a_left)
                pto.TMovOp(None, b_mat, b_right)
                pto.TMatmulOp(None, a_left, b_right, src_acc)

                # Operation under test: the lowering must emit
                # TINSERT(dst, src, row, col).
                pto.TAssembleOp(src_acc, c0, c0, dst_mat)

                # Second matmul consumes the assembled tile; its result is stored.
                pto.TLoadOp(None, part_i, i_mat)
                pto.TMovOp(None, dst_mat, out_left)
                pto.TMovOp(None, i_mat, i_right)
                pto.TMatmulOp(None, out_left, i_right, out_acc)
                pto.TStoreOp(None, out_acc, part_out)

                func.ReturnOp([])

            module.operation.verify()
            return module


# When executed as a script, build the sample module and print it to stdout.
if __name__ == "__main__":
    print(build())
Loading
Loading