Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 161 additions & 20 deletions lib/PTO/Transforms/PTOToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ static constexpr unsigned kPTOIndexBitWidth =
32; // keep consistent with IndexType conversion

// Forward declarations (definitions below).
static inline std::string pipeTokFromPipeAttr(mlir::pto::PipeAttr a);
static emitc::OpaqueType getSignedIntOpaqueType(MLIRContext *ctx,
unsigned bitWidth);
static emitc::OpaqueType getUnsignedIntOpaqueType(MLIRContext *ctx,
Expand All @@ -322,6 +323,104 @@ static Value castSignlessIntToUnsignedSameWidth(ConversionPatternRewriter &rewri
Location loc, Value v,
unsigned bitWidth);

static bool isSetFFTsPointerLikeType(Type ty) {
if (isa<emitc::PointerType>(ty))
return true;
if (auto opaqueTy = dyn_cast<emitc::OpaqueType>(ty))
return opaqueTy.getValue().ends_with("*");
return false;
}

struct InterCoreSyncCallDesc {
const char *callee = nullptr;
ArrayAttr args;
SmallVector<Value, 2> operands;
};

static InterCoreSyncCallDesc buildInterCoreSyncSetCall(
ConversionPatternRewriter &rewriter, Location loc, PTOArch targetArch,
pto::PipeAttr pipeAttr, IntegerAttr eventIdAttr) {
auto *ctx = rewriter.getContext();
std::string pipeTok = pipeTokFromPipeAttr(pipeAttr);

if (targetArch == PTOArch::A3) {
auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t");
Value eventVal =
makeEmitCIntConstant(rewriter, loc, i32Ty, eventIdAttr.getInt());

auto msgTy = emitc::OpaqueType::get(ctx, "uint16_t");
auto msgArgs = rewriter.getArrayAttr({
emitc::OpaqueAttr::get(ctx, "FFTS_MODE_VAL"),
IntegerAttr::get(IndexType::get(ctx), 0),
});
Value msgVal =
rewriter
.create<emitc::CallOpaqueOp>(loc, msgTy, "getFFTSMsg",
/*args=*/msgArgs,
/*templateArgs=*/ArrayAttr{},
/*operands=*/ValueRange{eventVal})
.getResult(0);

InterCoreSyncCallDesc desc;
desc.callee = "ffts_cross_core_sync";
desc.args = rewriter.getArrayAttr({
emitc::OpaqueAttr::get(ctx, pipeTok),
IntegerAttr::get(IndexType::get(ctx), 0),
});
desc.operands.push_back(msgVal);
return desc;
}

InterCoreSyncCallDesc desc;
desc.callee = "set_intra_block";
desc.args = rewriter.getArrayAttr(
{emitc::OpaqueAttr::get(ctx, pipeTok), eventIdAttr});
return desc;
}

static InterCoreSyncCallDesc buildInterCoreSyncWaitCall(
ConversionPatternRewriter &rewriter, PTOArch targetArch,
pto::PipeAttr pipeAttr, IntegerAttr eventIdAttr) {
auto *ctx = rewriter.getContext();
std::string pipeTok = pipeTokFromPipeAttr(pipeAttr);

InterCoreSyncCallDesc desc;
if (targetArch == PTOArch::A3) {
desc.callee = "wait_flag_dev";
desc.args = rewriter.getArrayAttr({eventIdAttr});
return desc;
}

desc.callee = "wait_intra_block";
desc.args = rewriter.getArrayAttr(
{emitc::OpaqueAttr::get(ctx, pipeTok), eventIdAttr});
return desc;
}

static bool hasInterCoreSyncOp(func::FuncOp func) {
bool found = false;
func.walk([&](Operation *op) {
if (isa<pto::SyncSetOp, pto::SyncWaitOp>(op)) {
found = true;
return WalkResult::interrupt();
}
return WalkResult::advance();
});
return found;
}

static bool hasSetFFTsOp(func::FuncOp func) {
bool found = false;
func.walk([&](Operation *op) {
if (isa<pto::SetFFTsOp>(op)) {
found = true;
return WalkResult::interrupt();
}
return WalkResult::advance();
});
return found;
}

//===----------------------------------------------------------------------===//
// Arith -> EmitC (full dialect coverage for scalar ops)
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -3686,6 +3785,41 @@ struct PTORlsBufToEmitC : public OpConversionPattern<mlir::pto::RlsBufOp> {
}
};

struct PTOSetFFTsToEmitC : public OpConversionPattern<mlir::pto::SetFFTsOp> {
using OpConversionPattern<mlir::pto::SetFFTsOp>::OpConversionPattern;

LogicalResult matchAndRewrite(mlir::pto::SetFFTsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto *ctx = rewriter.getContext();
auto loc = op.getLoc();

Value fftsAddr = peelUnrealized(adaptor.getFfts());
auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t");

if (isSetFFTsPointerLikeType(fftsAddr.getType())) {
auto castTyAttr =
rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")});
fftsAddr =
rewriter
.create<emitc::CallOpaqueOp>(loc, u64Ty, "reinterpret_cast",
/*args=*/ArrayAttr{},
/*templateArgs=*/castTyAttr,
/*operands=*/ValueRange{fftsAddr})
.getResult(0);
} else if (fftsAddr.getType() != u64Ty) {
fftsAddr =
rewriter.create<emitc::CastOp>(loc, u64Ty, fftsAddr).getResult();
}

rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
op, TypeRange{}, "set_ffts_base_addr",
/*args=*/ArrayAttr{},
/*templateArgs=*/ArrayAttr{},
/*operands=*/ValueRange{fftsAddr});
return success();
}
};

struct PTOSyncSetToEmitC : public OpConversionPattern<mlir::pto::SyncSetOp> {
PTOSyncSetToEmitC(TypeConverter &typeConverter, MLIRContext *ctx,
PTOArch targetArch)
Expand All @@ -3696,19 +3830,13 @@ struct PTOSyncSetToEmitC : public OpConversionPattern<mlir::pto::SyncSetOp> {
matchAndRewrite(mlir::pto::SyncSetOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
(void)adaptor;
auto *ctx = rewriter.getContext();
auto loc = op->getLoc();

std::string pipeTok = pipeTokFromPipeAttr(op.getPipe());
auto argsAttr = rewriter.getArrayAttr(
{emitc::OpaqueAttr::get(ctx, pipeTok), op.getEventIdAttr()});
const char *kSyncSetCallee = (targetArch == PTOArch::A3)
? "ffts_cross_core_sync"
: "set_intra_block";
rewriter.create<emitc::CallOpaqueOp>(loc, TypeRange{}, kSyncSetCallee,
/*args=*/argsAttr,
auto desc = buildInterCoreSyncSetCall(rewriter, loc, targetArch, op.getPipe(),
op.getEventIdAttr());
rewriter.create<emitc::CallOpaqueOp>(loc, TypeRange{}, desc.callee,
/*args=*/desc.args,
/*templateArgs=*/ArrayAttr{},
/*operands=*/ValueRange{});
/*operands=*/desc.operands);

rewriter.eraseOp(op);
return success();
Expand All @@ -3727,16 +3855,11 @@ struct PTOSyncWaitToEmitC : public OpConversionPattern<mlir::pto::SyncWaitOp> {
matchAndRewrite(mlir::pto::SyncWaitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
(void)adaptor;
auto *ctx = rewriter.getContext();
auto loc = op->getLoc();

std::string pipeTok = pipeTokFromPipeAttr(op.getPipe());
auto argsAttr = rewriter.getArrayAttr(
{emitc::OpaqueAttr::get(ctx, pipeTok), op.getEventIdAttr()});
const char *kSyncWaitCallee =
(targetArch == PTOArch::A3) ? "wait_flag_dev" : "wait_intra_block";
rewriter.create<emitc::CallOpaqueOp>(loc, TypeRange{}, kSyncWaitCallee,
argsAttr, ArrayAttr{}, ValueRange{});
auto desc = buildInterCoreSyncWaitCall(rewriter, targetArch, op.getPipe(),
op.getEventIdAttr());
rewriter.create<emitc::CallOpaqueOp>(loc, TypeRange{}, desc.callee,
desc.args, ArrayAttr{}, desc.operands);

rewriter.eraseOp(op);
return success();
Expand Down Expand Up @@ -7258,6 +7381,7 @@ static void populatePTOToEmitCPatterns(RewritePatternSet &patterns,
patterns.add<PTOWaitFlagToEmitC>(typeConverter, ctx);
patterns.add<PTOGetBufToEmitC>(typeConverter, ctx);
patterns.add<PTORlsBufToEmitC>(typeConverter, ctx);
patterns.add<PTOSetFFTsToEmitC>(typeConverter, ctx);
patterns.add<PTOXORSToEmitC>(typeConverter, ctx);
patterns.add<PTOSYNCToEmitC>(typeConverter, ctx);
patterns.add<PTOSubSToEmitC>(typeConverter, ctx);
Expand Down Expand Up @@ -7456,6 +7580,23 @@ struct EmitPTOManualPass
MLIRContext *ctx = &getContext();
ModuleOp mop = getOperation();

// A3 requires explicit FFTS base setup for inter-core sync ops.
if (targetArch == PTOArch::A3) {
bool hasMissingSetFFTs = false;
for (auto func : mop.getOps<func::FuncOp>()) {
if (!hasInterCoreSyncOp(func))
continue;
if (hasSetFFTsOp(func))
continue;
hasMissingSetFFTs = true;
func.emitError()
<< "A3 inter-core sync requires explicit `pto.set_ffts` in the "
"same function when using `pto.sync.set`/`pto.sync.wait`";
}
if (hasMissingSetFFTs)
return signalPassFailure();
}

// 1. 插入头文件
auto loc = mop->getLoc();
OpBuilder builder(ctx);
Expand Down
37 changes: 37 additions & 0 deletions python/pto/dialects/pto.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ def _load_local_pto_ext():
"TileConfig",
# High-level sync helpers
"record_event", "wait_event", "barrier",
# Inter-core sync helpers
"sync_set", "sync_wait", "set_ffts",
# A5 buffer-id sync helpers
"get_buf", "rls_buf",
# Scalar pointer helpers
Expand Down Expand Up @@ -168,6 +170,41 @@ def barrier(op, *, loc=None, ip=None):
# Otherwise fall back to low-level barrier expecting PipeAttr
return _pto_ops_gen.barrier(op, loc=loc, ip=ip)

# -----------------------------------------------------------------------------
# Inter-core sync helpers (pto.sync.set / pto.sync.wait / pto.set_ffts)
# -----------------------------------------------------------------------------
def sync_set(pipe, event_id, *, loc=None, ip=None):
ctx = loc.context if loc else _ods_ir.Context.current
return _ods_ir.Operation.create(
"pto.sync.set",
attributes={
"pipe": _ensure_pipe_attr(pipe, ctx),
"event_id": _ensure_i32_attr(event_id, "event_id", ctx),
},
loc=loc,
ip=ip,
)

def sync_wait(pipe, event_id, *, loc=None, ip=None):
ctx = loc.context if loc else _ods_ir.Context.current
return _ods_ir.Operation.create(
"pto.sync.wait",
attributes={
"pipe": _ensure_pipe_attr(pipe, ctx),
"event_id": _ensure_i32_attr(event_id, "event_id", ctx),
},
loc=loc,
ip=ip,
)

def set_ffts(ffts, *, loc=None, ip=None):
return _ods_ir.Operation.create(
"pto.set_ffts",
operands=[_pto_ops_gen._get_op_result_or_value(ffts)],
loc=loc,
ip=ip,
)

# -----------------------------------------------------------------------------
# A5 buffer-id sync helpers
# -----------------------------------------------------------------------------
Expand Down
34 changes: 34 additions & 0 deletions test/samples/Sync/test_intercore_sync_a3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#!/usr/bin/env python3
from mlir.ir import Context, InsertionPoint, IntegerType, Location, MemRefType, Module
from mlir.dialects import func, pto


def build():
with Context() as ctx:
pto.register_dialect(ctx, load=True)
with Location.unknown(ctx):
module = Module.create()

i64 = IntegerType.get_signless(64, ctx)
# Minimal valid memref operand for pto.set_ffts verifier (i64 element).
ffts_ty = MemRefType.get([1], i64)
fn_ty = func.FunctionType.get([ffts_ty], [])

with InsertionPoint(module.body):
fn = func.FuncOp("test_intercore_sync_a3", fn_ty)
entry = fn.add_entry_block()

with InsertionPoint(entry):
pipe_fix = pto.PipeAttr.get(pto.PIPE.PIPE_FIX, ctx)
pipe_v = pto.PipeAttr.get(pto.PIPE.PIPE_V, ctx)
pto.set_ffts(entry.arguments[0])
pto.sync_set(pipe_fix, 3)
pto.sync_wait(pipe_v, 3)
func.ReturnOp([])

module.operation.verify()
return module


if __name__ == "__main__":
print(build())
29 changes: 29 additions & 0 deletions test/samples/Sync/test_intercore_sync_a3_missing_setffts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#!/usr/bin/env python3
from mlir.ir import Context, InsertionPoint, Location, Module
from mlir.dialects import func, pto


def build():
with Context() as ctx:
pto.register_dialect(ctx, load=True)
with Location.unknown(ctx):
module = Module.create()
fn_ty = func.FunctionType.get([], [])

with InsertionPoint(module.body):
fn = func.FuncOp("test_intercore_sync_a3_missing_setffts", fn_ty)
entry = fn.add_entry_block()

with InsertionPoint(entry):
pipe_fix = pto.PipeAttr.get(pto.PIPE.PIPE_FIX, ctx)
pipe_v = pto.PipeAttr.get(pto.PIPE.PIPE_V, ctx)
pto.sync_set(pipe_fix, 7)
pto.sync_wait(pipe_v, 7)
func.ReturnOp([])

module.operation.verify()
return module


if __name__ == "__main__":
print(build())
29 changes: 29 additions & 0 deletions test/samples/Sync/test_intercore_sync_a5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#!/usr/bin/env python3
from mlir.ir import Context, InsertionPoint, Location, Module
from mlir.dialects import func, pto


def build():
with Context() as ctx:
pto.register_dialect(ctx, load=True)
with Location.unknown(ctx):
module = Module.create()
fn_ty = func.FunctionType.get([], [])

with InsertionPoint(module.body):
fn = func.FuncOp("test_intercore_sync_a5", fn_ty)
entry = fn.add_entry_block()

with InsertionPoint(entry):
pipe_fix = pto.PipeAttr.get(pto.PIPE.PIPE_FIX, ctx)
pipe_v = pto.PipeAttr.get(pto.PIPE.PIPE_V, ctx)
pto.sync_set(pipe_fix, 5)
pto.sync_wait(pipe_v, 5)
func.ReturnOp([])

module.operation.verify()
return module


if __name__ == "__main__":
print(build())
Loading
Loading