Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions docs/ir/PTO-IR-vf-vops-design.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# PTO IR: VF/VOPS (vtile) design notes

This document is **restored from OpenClaw session logs** (jsonl) and recent chat decisions.
It records the intended IR surface syntax and canonicalization expectations for the VF/VOPS layer.

## Types

- `!pto.vtile<LANESxELEM>`
- example: `!pto.vtile<64xf32>`
- `!pto.uscalar<ELEM>`
- example: `!pto.uscalar<f32>`
- `!pto.preg`

## Target config

Attach `pto.target_config` on module or function:

```mlir
module attributes {
pto.target_config = #pto.target_config<arch=a3, isa="kirin9030", repeat_bytes=256, block_bytes=32, caps={}>
} {
func.func @k() { return }
}
```

## Core ops

- `pto.vf.scope { ... }`
- Predication:
- `pto.vpred.all`
- `pto.vpred.tail %count`
- Loads/stores:
- `pto.vload %tile, %row, %col, %pred`
- `pto.vstore %tile, %row, %col, %value, %pred`
- `pto.vload_tail %tile, %row, %col, %count`
- `pto.vstore_tail %tile, %row, %col, %count, %value`

## Canonicalization (pass: `-pto-canonicalize-vops`)

- `vload/vstore` with `vpred.tail(count)` should be rewritten to `vload_tail/vstore_tail`.
- If an operand is produced by `vload_tail(count)`, downstream binops/stores should use `vpred.tail(count)` / `vstore_tail`.
- If `count == lanes` (constant), tail ops may be simplified to non-tail ops.
- Conservative loop-invariant hoisting may move pure pto ops that do not depend on the induction variable out of a `scf.for`.
79 changes: 79 additions & 0 deletions include/PTO/IR/PTOAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -437,4 +437,83 @@ def TileBufConfigAttr : AttrDef<PTO_Dialect, "TileBufConfig"> {
}];
}

//===----------------------------------------------------------------------===//
// Target Config (VF/VOPS)
//===----------------------------------------------------------------------===//

// #pto.target_config<arch=a3|a5, isa="...", variant="...", repeat_bytes=256, block_bytes=32, caps={...}>

def TargetConfigAttr : AttrDef<PTO_Dialect, "TargetConfig"> {
let mnemonic = "target_config";
let summary = "Target configuration for VF/VOPS emission.";

let parameters = (ins
"mlir::StringAttr":$arch, // required: "a3" | "a5"
"mlir::StringAttr":$isa, // optional
"mlir::StringAttr":$variant, // optional
"mlir::IntegerAttr":$repeatBytes, // optional
"mlir::IntegerAttr":$blockBytes, // optional
"mlir::DictionaryAttr":$caps // optional, default empty dict
);

let hasCustomAssemblyFormat = 1;

let extraClassDeclaration = [{
static ::mlir::Attribute parse(::mlir::AsmParser &parser, ::mlir::Type) {
if (failed(parser.parseLess())) return {};
auto ctx = parser.getContext();

::mlir::StringAttr arch;
::mlir::StringAttr isa;
::mlir::StringAttr variant;
::mlir::IntegerAttr repeatBytes;
::mlir::IntegerAttr blockBytes;
::mlir::DictionaryAttr caps;

while (true) {
llvm::StringRef key;
if (failed(parser.parseKeyword(&key))) return {};
if (failed(parser.parseEqual())) return {};

if (key == "arch") {
llvm::StringRef av;
if (failed(parser.parseKeyword(&av))) return {};
if (av != "a3" && av != "a5") return {};
arch = ::mlir::StringAttr::get(ctx, av);
} else if (key == "isa") {
if (failed(parser.parseAttribute(isa))) return {};
} else if (key == "variant") {
if (failed(parser.parseAttribute(variant))) return {};
} else if (key == "repeat_bytes") {
if (failed(parser.parseAttribute(repeatBytes))) return {};
} else if (key == "block_bytes") {
if (failed(parser.parseAttribute(blockBytes))) return {};
} else if (key == "caps") {
if (failed(parser.parseAttribute(caps))) return {};
} else {
return {};
}

if (succeeded(parser.parseOptionalGreater())) break;
if (failed(parser.parseComma())) return {};
}

if (!arch) return {};
if (!caps) caps = ::mlir::DictionaryAttr::get(ctx);
return Base::get(ctx, arch, isa, variant, repeatBytes, blockBytes, caps);
}

void print(::mlir::AsmPrinter &printer) const {
printer << "<";
printer << "arch=" << getArch().getValue();
if (getIsa()) printer << ", isa=" << getIsa();
if (getVariant()) printer << ", variant=" << getVariant();
if (getRepeatBytes()) printer << ", repeat_bytes=" << getRepeatBytes();
if (getBlockBytes()) printer << ", block_bytes=" << getBlockBytes();
if (getCaps() && !getCaps().empty()) printer << ", caps=" << getCaps();
printer << ">";
}
}];
}

#endif // MLIR_DIALECT_PTO_IR_PTOATTRS
173 changes: 173 additions & 0 deletions include/PTO/IR/PTOOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3645,4 +3645,177 @@ def TPrintOp: PTO_TOp<"tprint", [
}];
}

//===----------------------------------------------------------------------===//
// VF / VOPS (vector-tile ops)
//===----------------------------------------------------------------------===//

// pto.vf.scope { ... }
def VFScopeOp : PTO_Op<"vf.scope", [IsolatedFromAbove, NoRegionArguments, SingleBlockImplicitTerminator<"YieldOp">]> {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Allow vf.scope to capture surrounding SSA values

Removing data flow in and out of pto.vf.scope makes the op effectively unusable for real kernels: with IsolatedFromAbove, no region arguments, and no yielded values, the body cannot read function arguments or previously computed tiles/pointers, so typical vload/vstore sequences cannot operate on surrounding state. This is a behavioral regression for any lowering that expects vf.scope to wrap vectorized code over existing operands.

Useful? React with 👍 / 👎.

let summary = "VF scope for explicit V-pipe ops.";
let regions = (region AnyRegion:$body);
let assemblyFormat = "attr-dict $body";
}

// yield terminator for vf.scope
def YieldOp : PTO_Op<"yield", [Terminator, HasParent<"VFScopeOp">]> {
let summary = "Terminator for pto.vf.scope";
let arguments = (ins);
let results = (outs);
let assemblyFormat = "attr-dict";
}

// pto.vpred.all : !pto.preg
def VPredAllOp : PTO_Op<"vpred.all", [Pure]> {
let summary = "Create an all-true predicate.";
let results = (outs PregType:$pred);
let assemblyFormat = "attr-dict `:` type($pred)";
}

// pto.vpred.tail %count : !pto.preg
def VPredTailOp : PTO_Op<"vpred.tail", [Pure]> {
let summary = "Create a tail predicate for a given element count.";
let arguments = (ins Index:$count);
let results = (outs PregType:$pred);
let assemblyFormat = "$count attr-dict `:` type($pred)";
}

// pto.uload_row %tile, %row : !pto.uscalar<elem>
def ULoadRowOp : PTO_Op<"uload_row", [Pure]> {
let summary = "Uniform scalar load for RowExpand-like patterns.";
let arguments = (ins TileBufType:$tile, Index:$row);
let results = (outs UScalarType:$value);
let hasVerifier = 1;
let assemblyFormat = "$tile `,` $row attr-dict `:` type($value)";
}

// pto.vdup %u, %pred : !pto.vtile<...>
def VDupOp : PTO_Op<"vdup", [Pure]> {
let summary = "Duplicate a uniform scalar into a vtile under predicate.";
let arguments = (ins UScalarType:$src, PregType:$pred);
let results = (outs VTileType:$dst);
let hasVerifier = 1;
let assemblyFormat = "$src `,` $pred attr-dict `:` type($dst)";
}

// pto.vload %tile, %row, %col, %pred : !pto.vtile<...>
def VLoadOp : PTO_Op<"vload", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "Vector load from tile buffer at (row,col) under predicate.";
let arguments = (ins TileBufType:$tile, Index:$row, Index:$col, PregType:$pred);
let results = (outs VTileType:$value);
let hasVerifier = 1;
let assemblyFormat = "$tile `,` $row `,` $col `,` $pred attr-dict `:` type($value)";
}

// pto.vstore %tile, %row, %col, %value, %pred
def VStoreOp : PTO_Op<"vstore", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "Vector store to tile buffer at (row,col) under predicate.";
let arguments = (ins TileBufType:$tile, Index:$row, Index:$col, VTileType:$value, PregType:$pred);
let results = (outs);
let hasVerifier = 1;
let assemblyFormat = "$tile `,` $row `,` $col `,` $value `,` $pred attr-dict `:` type($value)";
}

// pto.vload_tail %tile, %row, %col, %count : !pto.vtile<...>
def VLoadTailOp : PTO_Op<"vload_tail", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "Tail-safe vector load with explicit count.";
let arguments = (ins TileBufType:$tile, Index:$row, Index:$col, Index:$count);
let results = (outs VTileType:$value);
let hasVerifier = 1;
let assemblyFormat = "$tile `,` $row `,` $col `,` $count attr-dict `:` type($value)";
}

// pto.vstore_tail %tile, %row, %col, %count, %value
def VStoreTailOp : PTO_Op<"vstore_tail", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "Tail-safe vector store with explicit count.";
let arguments = (ins TileBufType:$tile, Index:$row, Index:$col, Index:$count, VTileType:$value);
let results = (outs);
let hasVerifier = 1;
let assemblyFormat = "$tile `,` $row `,` $col `,` $count `,` $value attr-dict `:` type($value)";
}

// pto.vload_block %tile, %row : !pto.vtile<...>
def VLoadBlockOp : PTO_Op<"vload_block", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "Block load used by RowExpand block-broadcast patterns.";
let arguments = (ins TileBufType:$tile, Index:$row);
let results = (outs VTileType:$value);
let hasVerifier = 1;
let assemblyFormat = "$tile `,` $row attr-dict `:` type($value)";
}

// pto.vlane_adapt %blk : !pto.vtile<...>
def VLaneAdaptOp : PTO_Op<"vlane_adapt", [Pure]> {
let summary = "Adapt lanes from a block vtile to a full vtile.";
let arguments = (ins VTileType:$src);
let results = (outs VTileType:$dst);
let hasVerifier = 1;
let assemblyFormat = "$src attr-dict `:` type($dst)";
}

// Binops: (vtile, vtile, preg) -> vtile

def VAddOp : PTO_Op<"vadd", [Pure]> {
let summary = "Vector add.";
let arguments = (ins VTileType:$lhs, VTileType:$rhs, PregType:$pred);
let results = (outs VTileType:$dst);
let hasVerifier = 1;
let assemblyFormat = "$lhs `,` $rhs `,` $pred attr-dict `:` type($dst)";
}

def VSubOp : PTO_Op<"vsub", [Pure]> {
let summary = "Vector sub.";
let arguments = (ins VTileType:$lhs, VTileType:$rhs, PregType:$pred);
let results = (outs VTileType:$dst);
let hasVerifier = 1;
let assemblyFormat = "$lhs `,` $rhs `,` $pred attr-dict `:` type($dst)";
}

def VMulOp : PTO_Op<"vmul", [Pure]> {
let summary = "Vector mul.";
let arguments = (ins VTileType:$lhs, VTileType:$rhs, PregType:$pred);
let results = (outs VTileType:$dst);
let hasVerifier = 1;
let assemblyFormat = "$lhs `,` $rhs `,` $pred attr-dict `:` type($dst)";
}

def VMinOp : PTO_Op<"vmin", [Pure]> {
let summary = "Vector min.";
let arguments = (ins VTileType:$lhs, VTileType:$rhs, PregType:$pred);
let results = (outs VTileType:$dst);
let hasVerifier = 1;
let assemblyFormat = "$lhs `,` $rhs `,` $pred attr-dict `:` type($dst)";
}

def VMaxOp : PTO_Op<"vmax", [Pure]> {
let summary = "Vector max.";
let arguments = (ins VTileType:$lhs, VTileType:$rhs, PregType:$pred);
let results = (outs VTileType:$dst);
let hasVerifier = 1;
let assemblyFormat = "$lhs `,` $rhs `,` $pred attr-dict `:` type($dst)";
}

def VAndOp : PTO_Op<"vand", [Pure]> {
let summary = "Vector and.";
let arguments = (ins VTileType:$lhs, VTileType:$rhs, PregType:$pred);
let results = (outs VTileType:$dst);
let hasVerifier = 1;
let assemblyFormat = "$lhs `,` $rhs `,` $pred attr-dict `:` type($dst)";
}

def VOrOp : PTO_Op<"vor", [Pure]> {
let summary = "Vector or.";
let arguments = (ins VTileType:$lhs, VTileType:$rhs, PregType:$pred);
let results = (outs VTileType:$dst);
let hasVerifier = 1;
let assemblyFormat = "$lhs `,` $rhs `,` $pred attr-dict `:` type($dst)";
}

def VXorOp : PTO_Op<"vxor", [Pure]> {
let summary = "Vector xor.";
let arguments = (ins VTileType:$lhs, VTileType:$rhs, PregType:$pred);
let results = (outs VTileType:$dst);
let hasVerifier = 1;
let assemblyFormat = "$lhs `,` $rhs `,` $pred attr-dict `:` type($dst)";
}


#endif // MLIR_DIALECT_PTO_IR_PTOOPS
69 changes: 69 additions & 0 deletions include/PTO/IR/PTOTypeDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,72 @@ def TileBufType : TypeDef<PTO_Dialect, "TileBuf"> {
int32_t getPadValueI32() const; // 0 null, 1 zero, 2 max, 3 min
}];
}

// ---- !pto.preg ----
// Predicate register used by vops.
def PregType : TypeDef<PTO_Dialect, "Preg"> {
let mnemonic = "preg";
let summary = "Predicate register type used by VOPS.";
}

// ---- !pto.uscalar<elem> ----
// Uniform scalar (per-thread uniform) used by vops patterns.
def UScalarType : TypeDef<PTO_Dialect, "UScalar"> {
let mnemonic = "uscalar";
let summary = "Uniform scalar value used for scalar+SIMD patterns inside pto.vf.scope.";
let parameters = (ins "mlir::Type":$elementType);

// Print/parse as: !pto.uscalar<elementType>
let hasCustomAssemblyFormat = 1;

let extraClassDeclaration = [{
static ::mlir::Type parse(::mlir::AsmParser &parser) {
if (failed(parser.parseLess())) return {};
::mlir::Type elem;
if (failed(parser.parseType(elem))) return {};
if (failed(parser.parseGreater())) return {};
return Base::get(parser.getContext(), elem);
}

void print(::mlir::AsmPrinter &printer) const {
printer << "<";
printer.printType(getElementType());
printer << ">";
}
}];
}

// ---- !pto.vtile<elem, lanes> ----
// Vector tile value used in vops. Lanes is typically elements-per-repeat (EPR).
def VTileType : TypeDef<PTO_Dialect, "VTile"> {
let mnemonic = "vtile";
let summary = "Vector tile value used inside pto.vf.scope (maps to vreg on A5, subtile view on A3).";
let parameters = (ins
"mlir::Type":$elementType,
"int64_t":$lanes
);

// Print/parse as: !pto.vtile<lanes x elementType>
let hasCustomAssemblyFormat = 1;

let extraClassDeclaration = [{
static ::mlir::Type parse(::mlir::AsmParser &parser) {
if (failed(parser.parseLess())) return {};
int64_t lanes = 0;
if (failed(parser.parseInteger(lanes))) return {};
if (failed(parser.parseKeyword("x"))) return {};
::mlir::Type elem;
if (failed(parser.parseType(elem))) return {};
if (failed(parser.parseGreater())) return {};
return Base::get(parser.getContext(), elem, lanes);
}

void print(::mlir::AsmPrinter &printer) const {
printer << "<" << getLanes() << "x";
printer.printType(getElementType());
printer << ">";
}

int64_t lanes() const { return getLanes(); }
}];
}
12 changes: 12 additions & 0 deletions include/PTO/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,16 @@ def PTOLoweringSyncToPipe : Pass<"pto-lowering-sync-to-pipe", "func::FuncOp"> {
];
}



def PTOCanonicalizeVops : Pass<"pto-canonicalize-vops", "func::FuncOp"> {
let summary = "Canonicalize VOPS patterns (tail ops, pred propagation, hoisting).";
let constructor = "mlir::pto::createPTOCanonicalizeVopsPass()";
let dependentDialects = [
"mlir::pto::PTODialect",
"mlir::scf::SCFDialect",
"mlir::arith::ArithDialect"
];
}

#endif // MLIR_DIALECT_PTO_PASSES
Loading
Loading