From acac3248d81f002c3e65fee1e65f590c7fca7299 Mon Sep 17 00:00:00 2001 From: Ettore Tiotto Date: Fri, 12 Sep 2025 16:55:24 +0000 Subject: [PATCH 01/12] WIP: Generate FMA loop Signed-off-by: Ettore Tiotto --- .../DotOpToLLVM/FMADotUtility.cpp | 177 ++++++++++++++++++ 1 file changed, 177 insertions(+) diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp index fa2c814722..a6d25df7b6 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp @@ -1,5 +1,11 @@ #include "triton/Conversion/TritonGPUToLLVM/FMADotUtility.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Value.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/raw_ostream.h" using namespace mlir; @@ -71,10 +77,57 @@ ValueTableFMA getValueTableFromStructFMA( return res; } +// Create an empty loop using the given lower bound \p lb, upper bound \p ub and +// step \p step. Return the body block of the created loop. +Block *createEmptyLoop(Value iv, Value ub, Value step, + ConversionPatternRewriter &rewriter, Location loc) { + MLIRContext *ctx = rewriter.getContext(); + Block *insertionBlock = rewriter.getInsertionBlock(); + Block *headerBlock = + rewriter.splitBlock(insertionBlock, rewriter.getInsertionPoint()); + Block *bodyBlock = rewriter.splitBlock(headerBlock, headerBlock->begin()); + Block *endBlock = rewriter.splitBlock(bodyBlock, bodyBlock->begin()); + rewriter.setInsertionPointToEnd(insertionBlock); + + // Loop header. + rewriter.create(loc, headerBlock, SmallVector{iv}); + rewriter.setInsertionPointToStart(headerBlock); + auto b = TritonLLVMOpBuilder(loc, rewriter); + rewriter.create(loc, b.icmp_slt(iv, ub), bodyBlock, + endBlock, SmallVector{iv}); + rewriter.setInsertionPointToStart(bodyBlock); + + // Loop body. + auto nextIv = b.add(iv, step); + rewriter.create(loc, headerBlock, SmallVector{nextIv}); + rewriter.setInsertionPointToStart(endBlock); + + return bodyBlock; +} + +// Initialize a variable to \p init and return the loaded value. +Value createIV(Value init, ConversionPatternRewriter &rewriter, Location loc) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto ptr = rewriter.create(loc, ptr_ty(rewriter.getContext()), + init.getType(), b.i32_val(1)); + rewriter.create(loc, init, ptr); + return rewriter.create(loc, init.getType(), ptr); +} + } // namespace namespace mlir::triton::gpu { +enum class CodeGenMode { + Unroll, + Loop, +} codeGenMode = CodeGenMode::Unroll; + +LogicalResult genFMALoop(DotOp, ValueTableFMA &, ValueTableFMA &, + ArrayRef, ArrayRef, + ArrayRef, unsigned, Type, + ConversionPatternRewriter &); + LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor, const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, @@ -134,6 +187,12 @@ LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor, SmallVector acc = cc; + if (codeGenMode == CodeGenMode::Loop) { + Type dType = typeConverter->convertType(dTensorTy); + return genFMALoop(op, has, hbs, acc, sizePerThread, repetitions, K, dType, + rewriter); + } + for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep) for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep) for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep) @@ -167,4 +226,122 @@ LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor, return success(); } +LogicalResult genFMALoop(DotOp op, ValueTableFMA &has, ValueTableFMA &hbs, + ArrayRef acc, ArrayRef sizePerThread, + ArrayRef repetitions, unsigned K, Type dType, + ConversionPatternRewriter &rewriter) { + ModuleOp mod = op->getParentOfType(); + MLIRContext *ctx = rewriter.getContext(); + Location loc = op.getLoc(); + + // Copy struct into vector for operand A. + SmallVector v1; + for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep) + for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep) + for (unsigned b = 0; b < sizePerThread[0]; ++b) + for (unsigned m = 0; m < sizePerThread[1]; ++m) + for (unsigned k = 0; k < K; ++k) + v1.push_back(has.at({bRep, mRep, b, m, k})); + Value vecA = packLLVector(loc, v1, rewriter); + + // Copy struct into vector for operand B. + SmallVector v2; + for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep) + for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep) + for (unsigned b = 0; b < sizePerThread[0]; ++b) + for (unsigned n = 0; n < sizePerThread[2]; ++n) + for (unsigned k = 0; k < K; ++k) + v2.push_back(hbs.at({bRep, nRep, b, n, k})); + Value vecB = packLLVector(loc, v2, rewriter); + + // Copy struct into vector for operand C. + Value vecC = packLLVector(loc, acc, rewriter); + + const unsigned len = acc.size(); + Type elemType = acc.front().getType(); + auto builder = TritonLLVMOpBuilder(loc, rewriter); + Value vecD = builder.undef(vec_ty(elemType, len)); + + Value zero = builder.i32_val(0), one = builder.i32_val(1); + for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep) + for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep) + for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep) + for (unsigned b = 0; b < sizePerThread[0]; ++b) { + // Generate the outer loop. + Value outerIV = createIV(zero, rewriter, loc); + Value outerUB = builder.i32_val(sizePerThread[1]); + Value outerStep = builder.i32_val(sizePerThread[2]); + Block *outerBody = + createEmptyLoop(outerIV, outerUB, outerStep, rewriter, loc); + auto afterOuterLoop = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointToStart(outerBody); + + // Get the values for operand A. + SmallVector AElems; + for (unsigned i = 0; i < sizePerThread[2]; ++i) { + Value idx = builder.add(outerIV, builder.i32_val(i)); + AElems.push_back(builder.extract_element(vecA, idx)); + } + + // Generate the inner loop. + Value innerIV = createIV(zero, rewriter, loc); + Value innerUB = outerStep; + Value innerStep = one; + Block *innerBody = + createEmptyLoop(outerIV, innerUB, innerStep, rewriter, loc); + rewriter.setInsertionPointToStart(innerBody); + + // Get the values for operand B. + SmallVector BElems; + for (unsigned j = 0; j < sizePerThread[2]; ++j) { + Value idx = + builder.add(innerIV, builder.i32_val(sizePerThread[2] * j)); + BElems.push_back(builder.extract_element(vecB, idx)); + } + + // Get the value for operand C. + // TODO: generate FMA for integer here. + Value accIdx = builder.fma(innerUB, outerIV, innerIV); + Value acc = builder.extract_element(vecC, accIdx); + + // Perform the FMAs. + for (unsigned k = 0; k < sizePerThread[2]; ++k) { + TypeSwitch(elemType) + .Case([&](auto) { + acc = rewriter.create(loc, AElems[k], + BElems[k], acc); + }) + .Case([&](auto) { + acc = builder.fma(AElems[k], BElems[k], acc); + }); + } + + // Store the result. + builder.insert_element(vecD, acc, accIdx); + rewriter.restoreInsertionPoint(afterOuterLoop); + } + + // Create a loop to copy vecD into a struct. + Value ub = builder.i32_val(len); + auto structPtr = + rewriter.create(loc, ptr_ty(ctx), elemType, ub); + Value iv = createIV(zero, rewriter, loc); + Block *body = createEmptyLoop(iv, ub, one, rewriter, loc); + auto afterLoop = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointToStart(body); + Value val = builder.extract_element(vecD, iv); + Value ptr = builder.gep(ptr_ty(ctx), val.getType(), structPtr, iv); + rewriter.create(loc, val, ptr); + rewriter.restoreInsertionPoint(afterLoop); + auto loadVal = rewriter.create(loc, dType, structPtr); + rewriter.replaceOp(op, loadVal); + + llvm::errs() << "at line: " << __LINE__ << "\n"; + llvm::errs() << "Module after:\n"; + mod->dumpPretty(); + llvm::errs() << "\n"; + + return success(); +} + } // namespace mlir::triton::gpu From ae0382a2730ff156be159da46c840adae461a44b Mon Sep 17 00:00:00 2001 From: Ettore Tiotto Date: Fri, 12 Sep 2025 16:58:49 +0000 Subject: [PATCH 02/12] WIP: Generate FMA loop Signed-off-by: Ettore Tiotto --- lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp index a6d25df7b6..99cdb3315d 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp @@ -288,7 +288,7 @@ LogicalResult genFMALoop(DotOp op, ValueTableFMA &has, ValueTableFMA &hbs, Value innerUB = outerStep; Value innerStep = one; Block *innerBody = - createEmptyLoop(outerIV, innerUB, innerStep, rewriter, loc); + createEmptyLoop(innerIV, innerUB, innerStep, rewriter, loc); rewriter.setInsertionPointToStart(innerBody); // Get the values for operand B. From be0c48f54455b634ad93712f99199313a2f4ec82 Mon Sep 17 00:00:00 2001 From: Ettore Tiotto Date: Fri, 12 Sep 2025 13:02:10 -0400 Subject: [PATCH 03/12] Update lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp index 99cdb3315d..8487a80afe 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp @@ -301,7 +301,7 @@ LogicalResult genFMALoop(DotOp op, ValueTableFMA &has, ValueTableFMA &hbs, // Get the value for operand C. // TODO: generate FMA for integer here. - Value accIdx = builder.fma(innerUB, outerIV, innerIV); + Value accIdx = builder.add(builder.mul(innerUB, outerIV), innerIV); Value acc = builder.extract_element(vecC, accIdx); // Perform the FMAs. From b7b7720d8c0e4f14bd3b7de9ea694004bd2fb407 Mon Sep 17 00:00:00 2001 From: Ettore Tiotto Date: Fri, 12 Sep 2025 13:02:28 -0400 Subject: [PATCH 04/12] Update lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp index 8487a80afe..d272792a81 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp @@ -317,7 +317,7 @@ LogicalResult genFMALoop(DotOp op, ValueTableFMA &has, ValueTableFMA &hbs, } // Store the result. - builder.insert_element(vecD, acc, accIdx); + vecD = builder.insert_element(vecD, acc, accIdx); rewriter.restoreInsertionPoint(afterOuterLoop); } From 7af9b39a9d00e8fa22549612aefaf22a1a4a5af7 Mon Sep 17 00:00:00 2001 From: Ettore Tiotto Date: Tue, 16 Sep 2025 15:25:43 +0000 Subject: [PATCH 05/12] WIP: Generate FMA loop Signed-off-by: Ettore Tiotto --- .../DotOpToLLVM/FMADotUtility.cpp | 105 +++++++++++++----- 1 file changed, 78 insertions(+), 27 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp index d272792a81..a8a30d6ab1 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp @@ -77,35 +77,64 @@ ValueTableFMA getValueTableFromStructFMA( return res; } -// Create an empty loop using the given lower bound \p lb, upper bound \p ub and -// step \p step. Return the body block of the created loop. -Block *createEmptyLoop(Value iv, Value ub, Value step, - ConversionPatternRewriter &rewriter, Location loc) { +struct LoopInfo { + Block *header; + Block *body; + Block *end; +}; + +/// Creates an empty loop structure with a header, body, and end block. The +/// loop is initialized with an induction variable and an initial argument. +/// - Parameters: +/// - `iv`: The induction variable (already initialized to the lower bound). +/// - `ub`: The upper bound for the loop. +/// - `step`: The step for the induction variable. +/// - `initArg`: The initial argument passed to the loop. +/// - Returns a `LoopInfo` structure containing the header, body, and end/// +/// blocks. +LoopInfo createEmptyLoop(Value iv, Value ub, Value step, Value initArg, + ConversionPatternRewriter &rewriter, Location loc) { + // Create loop harness. + auto b = TritonLLVMOpBuilder(loc, rewriter); MLIRContext *ctx = rewriter.getContext(); Block *insertionBlock = rewriter.getInsertionBlock(); Block *headerBlock = rewriter.splitBlock(insertionBlock, rewriter.getInsertionPoint()); Block *bodyBlock = rewriter.splitBlock(headerBlock, headerBlock->begin()); Block *endBlock = rewriter.splitBlock(bodyBlock, bodyBlock->begin()); - rewriter.setInsertionPointToEnd(insertionBlock); + + Type ivTy = iv.getType(); + Type initArgTy = initArg.getType(); + headerBlock->addArguments({ivTy, initArgTy}, {loc, loc}); + bodyBlock->addArguments({ivTy, initArgTy}, {loc, loc}); + endBlock->addArgument(initArgTy, loc); // Loop header. - rewriter.create(loc, headerBlock, SmallVector{iv}); + rewriter.setInsertionPointToEnd(insertionBlock); + rewriter.create(loc, headerBlock, + SmallVector{iv, initArg}); rewriter.setInsertionPointToStart(headerBlock); - auto b = TritonLLVMOpBuilder(loc, rewriter); - rewriter.create(loc, b.icmp_slt(iv, ub), bodyBlock, - endBlock, SmallVector{iv}); + + auto cond = b.icmp_slt(headerBlock->getArgument(0), ub); + auto headerArgs = headerBlock->getArguments(); + rewriter.create(loc, cond, bodyBlock, headerArgs, endBlock, + headerArgs.drop_front()); rewriter.setInsertionPointToStart(bodyBlock); // Loop body. - auto nextIv = b.add(iv, step); - rewriter.create(loc, headerBlock, SmallVector{nextIv}); + auto nextIV = b.add(bodyBlock->getArgument(0), step); + SmallVector args1{nextIV}; + args1.push_back(bodyBlock->getArgument(1)); + rewriter.create(loc, headerBlock, args1); rewriter.setInsertionPointToStart(endBlock); - return bodyBlock; + return {headerBlock, bodyBlock, endBlock}; } -// Initialize a variable to \p init and return the loaded value. +/// Initializes a variable to a given value and returns the loaded value. +/// - Parameters: +/// - `init`: The initial value to assign to the variable. +/// - Returns: The loaded value of the initialized variable. Value createIV(Value init, ConversionPatternRewriter &rewriter, Location loc) { auto b = TritonLLVMOpBuilder(loc, rewriter); auto ptr = rewriter.create(loc, ptr_ty(rewriter.getContext()), @@ -121,7 +150,7 @@ namespace mlir::triton::gpu { enum class CodeGenMode { Unroll, Loop, -} codeGenMode = CodeGenMode::Unroll; +} codeGenMode = CodeGenMode::Loop; LogicalResult genFMALoop(DotOp, ValueTableFMA &, ValueTableFMA &, ArrayRef, ArrayRef, @@ -260,8 +289,8 @@ LogicalResult genFMALoop(DotOp op, ValueTableFMA &has, ValueTableFMA &hbs, const unsigned len = acc.size(); Type elemType = acc.front().getType(); auto builder = TritonLLVMOpBuilder(loc, rewriter); - Value vecD = builder.undef(vec_ty(elemType, len)); + Value vecD; Value zero = builder.i32_val(0), one = builder.i32_val(1); for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep) for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep) @@ -271,10 +300,14 @@ LogicalResult genFMALoop(DotOp op, ValueTableFMA &has, ValueTableFMA &hbs, Value outerIV = createIV(zero, rewriter, loc); Value outerUB = builder.i32_val(sizePerThread[1]); Value outerStep = builder.i32_val(sizePerThread[2]); - Block *outerBody = - createEmptyLoop(outerIV, outerUB, outerStep, rewriter, loc); + LoopInfo outerLoopInfo = createEmptyLoop(outerIV, outerUB, outerStep, + {vecC}, rewriter, loc); + Block *outerBody = outerLoopInfo.body; + Block *outerEnd = outerLoopInfo.end; + auto outerLatch = cast(outerBody->getTerminator()); + vecD = outerEnd->getArgument(0); auto afterOuterLoop = rewriter.saveInsertionPoint(); - rewriter.setInsertionPointToStart(outerBody); + rewriter.setInsertionPointToStart(outerLoopInfo.body); // Get the values for operand A. SmallVector AElems; @@ -287,9 +320,13 @@ LogicalResult genFMALoop(DotOp op, ValueTableFMA &has, ValueTableFMA &hbs, Value innerIV = createIV(zero, rewriter, loc); Value innerUB = outerStep; Value innerStep = one; - Block *innerBody = - createEmptyLoop(innerIV, innerUB, innerStep, rewriter, loc); - rewriter.setInsertionPointToStart(innerBody); + Value initArg = + outerBody->getArgument(outerBody->getNumArguments() - 1); + LoopInfo innerLoopInfo = createEmptyLoop(innerIV, innerUB, innerStep, + initArg, rewriter, loc); + Block *innerBody = innerLoopInfo.body; + Block *innerEnd = innerLoopInfo.end; + rewriter.setInsertionPointToStart(innerLoopInfo.body); // Get the values for operand B. SmallVector BElems; @@ -300,9 +337,10 @@ LogicalResult genFMALoop(DotOp op, ValueTableFMA &has, ValueTableFMA &hbs, } // Get the value for operand C. - // TODO: generate FMA for integer here. Value accIdx = builder.add(builder.mul(innerUB, outerIV), innerIV); - Value acc = builder.extract_element(vecC, accIdx); + Value innerInitArg = + innerBody->getArgument(innerBody->getNumArguments() - 1); + Value acc = builder.extract_element(innerInitArg, accIdx); // Perform the FMAs. for (unsigned k = 0; k < sizePerThread[2]; ++k) { @@ -317,21 +355,34 @@ LogicalResult genFMALoop(DotOp op, ValueTableFMA &has, ValueTableFMA &hbs, } // Store the result. - vecD = builder.insert_element(vecD, acc, accIdx); + innerInitArg = builder.insert_element(innerInitArg, acc, accIdx); rewriter.restoreInsertionPoint(afterOuterLoop); + + // Pass the result to the next inner loop iteration. + auto innerLatch = cast(innerBody->getTerminator()); + innerLatch->setOperand(innerLatch->getNumOperands() - 1, + innerInitArg); + + // Pass the result of the inner loop to the next outer loop iteration. + Value innerEndArg = innerEnd->getArgument(0); + outerLatch->setOperand(outerLatch->getNumOperands() - 1, innerEndArg); } - // Create a loop to copy vecD into a struct. + // Create a loop to copy the result into a struct. Value ub = builder.i32_val(len); auto structPtr = rewriter.create(loc, ptr_ty(ctx), elemType, ub); Value iv = createIV(zero, rewriter, loc); - Block *body = createEmptyLoop(iv, ub, one, rewriter, loc); + LoopInfo loopInfo = createEmptyLoop(iv, ub, one, vecD, rewriter, loc); + Block *body = loopInfo.body; auto afterLoop = rewriter.saveInsertionPoint(); rewriter.setInsertionPointToStart(body); - Value val = builder.extract_element(vecD, iv); + Value val = builder.extract_element( + body->getArgument(body->getNumArguments() - 1), iv); Value ptr = builder.gep(ptr_ty(ctx), val.getType(), structPtr, iv); rewriter.create(loc, val, ptr); + + // Load the struct and replace the original op. rewriter.restoreInsertionPoint(afterLoop); auto loadVal = rewriter.create(loc, dType, structPtr); rewriter.replaceOp(op, loadVal); From 1366d40e51d024f79201e6cde72866179d6d3e1e Mon Sep 17 00:00:00 2001 From: Ettore Tiotto Date: Tue, 16 Sep 2025 21:45:01 +0000 Subject: [PATCH 06/12] WIP: Generate FMA loop Signed-off-by: Ettore Tiotto --- include/triton/Tools/Sys/GetEnv.hpp | 1 + .../DotOpToLLVM/FMADotUtility.cpp | 144 ++++++++++-------- 2 files changed, 80 insertions(+), 65 deletions(-) diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index fa299fb36b..942f7f04aa 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -51,6 +51,7 @@ inline const std::set CACHE_INVALIDATING_ENV_VARS = { "TRITON_INTEL_ENABLE_DPAS_FOR_WARP_SIZE_32", "TRITON_INTEL_ENABLE_FIRST_LOAD_TO_SLM", "TRITON_INTEL_ENABLE_INSTR_SCHED", + "TRITON_INTEL_LOWER_DOT_TO_LOOP", "TRITON_INTEL_FAST_MATH", "TRITON_INTEL_ONE_MATRIX_PER_LOAD_BT", "TRITON_INTEL_REDUCE_TRANSPOSE", diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp index a8a30d6ab1..5c138ec70e 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp @@ -4,8 +4,8 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Value.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" #include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/raw_ostream.h" using namespace mlir; @@ -94,38 +94,41 @@ struct LoopInfo { /// blocks. LoopInfo createEmptyLoop(Value iv, Value ub, Value step, Value initArg, ConversionPatternRewriter &rewriter, Location loc) { - // Create loop harness. auto b = TritonLLVMOpBuilder(loc, rewriter); MLIRContext *ctx = rewriter.getContext(); + + // Create loop blocks. Block *insertionBlock = rewriter.getInsertionBlock(); Block *headerBlock = rewriter.splitBlock(insertionBlock, rewriter.getInsertionPoint()); Block *bodyBlock = rewriter.splitBlock(headerBlock, headerBlock->begin()); Block *endBlock = rewriter.splitBlock(bodyBlock, bodyBlock->begin()); + // Add arguments to blocks. Type ivTy = iv.getType(); Type initArgTy = initArg.getType(); headerBlock->addArguments({ivTy, initArgTy}, {loc, loc}); bodyBlock->addArguments({ivTy, initArgTy}, {loc, loc}); endBlock->addArgument(initArgTy, loc); - // Loop header. + // Connect insertion block to header block. rewriter.setInsertionPointToEnd(insertionBlock); - rewriter.create(loc, headerBlock, - SmallVector{iv, initArg}); + rewriter.create(loc, headerBlock, ValueRange{iv, initArg}); + + // Build header block. rewriter.setInsertionPointToStart(headerBlock); + Value cond = b.icmp_slt(headerBlock->getArgument(0), ub); + rewriter.create(loc, cond, bodyBlock, + headerBlock->getArguments(), endBlock, + ValueRange{headerBlock->getArgument(1)}); - auto cond = b.icmp_slt(headerBlock->getArgument(0), ub); - auto headerArgs = headerBlock->getArguments(); - rewriter.create(loc, cond, bodyBlock, headerArgs, endBlock, - headerArgs.drop_front()); + // Build body block. rewriter.setInsertionPointToStart(bodyBlock); + Value nextIV = b.add(bodyBlock->getArgument(0), step); + rewriter.create(loc, headerBlock, + ValueRange{nextIV, bodyBlock->getArgument(1)}); - // Loop body. - auto nextIV = b.add(bodyBlock->getArgument(0), step); - SmallVector args1{nextIV}; - args1.push_back(bodyBlock->getArgument(1)); - rewriter.create(loc, headerBlock, args1); + // Set insertion point to end block. rewriter.setInsertionPointToStart(endBlock); return {headerBlock, bodyBlock, endBlock}; @@ -147,15 +150,10 @@ Value createIV(Value init, ConversionPatternRewriter &rewriter, Location loc) { namespace mlir::triton::gpu { -enum class CodeGenMode { - Unroll, - Loop, -} codeGenMode = CodeGenMode::Loop; - LogicalResult genFMALoop(DotOp, ValueTableFMA &, ValueTableFMA &, ArrayRef, ArrayRef, ArrayRef, unsigned, Type, - ConversionPatternRewriter &); + ConversionPatternRewriter &, FMAVectorMultiplier &); LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor, const LLVMTypeConverter *typeConverter, @@ -216,10 +214,21 @@ LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor, SmallVector acc = cc; - if (codeGenMode == CodeGenMode::Loop) { + llvm::errs() << "repetitions: " << repetitions[0] << " " << repetitions[1] + << " " << repetitions[2] << "\n"; + llvm::errs() << "sizePerThread: " << sizePerThread[0] << " " + << sizePerThread[1] << " " << sizePerThread[2] << "\n"; + + auto mod = op->getParentOfType(); + llvm::errs() << "at line: " << __LINE__ << "\n"; + llvm::errs() << "Module: "; + mod->dumpPretty(); + llvm::errs() << "\n"; + + if (triton::tools::getBoolEnv("TRITON_INTEL_LOWER_DOT_TO_LOOP")) { Type dType = typeConverter->convertType(dTensorTy); return genFMALoop(op, has, hbs, acc, sizePerThread, repetitions, K, dType, - rewriter); + rewriter, multiplier); } for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep) @@ -252,89 +261,90 @@ LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor, auto res = packLLElements(loc, typeConverter, acc, rewriter, dTensorTy); rewriter.replaceOp(op, res); + llvm::errs() << "at line: " << __LINE__ << "\n"; + llvm::errs() << "Module: "; + mod->dumpPretty(); + llvm::errs() << "\n"; + return success(); } LogicalResult genFMALoop(DotOp op, ValueTableFMA &has, ValueTableFMA &hbs, ArrayRef acc, ArrayRef sizePerThread, ArrayRef repetitions, unsigned K, Type dType, - ConversionPatternRewriter &rewriter) { + ConversionPatternRewriter &rewriter, + FMAVectorMultiplier &multiplier) { ModuleOp mod = op->getParentOfType(); MLIRContext *ctx = rewriter.getContext(); Location loc = op.getLoc(); + auto builder = TritonLLVMOpBuilder(loc, rewriter); - // Copy struct into vector for operand A. - SmallVector v1; + // Copy struct into vector for operand A,B.C. + SmallVector aOpVector, bOpVector; for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep) for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep) - for (unsigned b = 0; b < sizePerThread[0]; ++b) - for (unsigned m = 0; m < sizePerThread[1]; ++m) - for (unsigned k = 0; k < K; ++k) - v1.push_back(has.at({bRep, mRep, b, m, k})); - Value vecA = packLLVector(loc, v1, rewriter); - - // Copy struct into vector for operand B. - SmallVector v2; - for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep) - for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep) - for (unsigned b = 0; b < sizePerThread[0]; ++b) - for (unsigned n = 0; n < sizePerThread[2]; ++n) - for (unsigned k = 0; k < K; ++k) - v2.push_back(hbs.at({bRep, nRep, b, n, k})); - Value vecB = packLLVector(loc, v2, rewriter); - - // Copy struct into vector for operand C. + for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep) + for (unsigned b = 0; b < sizePerThread[0]; ++b) + for (unsigned m = 0; m < sizePerThread[1]; ++m) + for (unsigned n = 0; n < sizePerThread[2]; ++n) + for (unsigned k = 0; k < K; ++k) { + aOpVector.push_back(has.at({bRep, mRep, b, m, k})); + bOpVector.push_back(hbs.at({bRep, nRep, b, n, k})); + } + + Value vecA = packLLVector(loc, aOpVector, rewriter); + Value vecB = packLLVector(loc, bOpVector, rewriter); Value vecC = packLLVector(loc, acc, rewriter); - const unsigned len = acc.size(); - Type elemType = acc.front().getType(); - auto builder = TritonLLVMOpBuilder(loc, rewriter); + auto getFragment = [&](Value vec, Value iv, unsigned size) { + SmallVector elems; + for (unsigned i = 0; i < size; ++i) { + Value idx = (i != 0) ? builder.add(iv, builder.i32_val(i)) : iv; + elems.push_back(builder.extract_element(vec, idx)); + } + return elems; + }; Value vecD; Value zero = builder.i32_val(0), one = builder.i32_val(1); + Type elemType = acc.front().getType(); + for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep) for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep) for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep) for (unsigned b = 0; b < sizePerThread[0]; ++b) { // Generate the outer loop. - Value outerIV = createIV(zero, rewriter, loc); Value outerUB = builder.i32_val(sizePerThread[1]); Value outerStep = builder.i32_val(sizePerThread[2]); - LoopInfo outerLoopInfo = createEmptyLoop(outerIV, outerUB, outerStep, - {vecC}, rewriter, loc); + LoopInfo outerLoopInfo = + createEmptyLoop(createIV(zero, rewriter, loc), outerUB, outerStep, + {vecC}, rewriter, loc); Block *outerBody = outerLoopInfo.body; Block *outerEnd = outerLoopInfo.end; + Value outerIV = outerBody->getArgument(0); auto outerLatch = cast(outerBody->getTerminator()); vecD = outerEnd->getArgument(0); auto afterOuterLoop = rewriter.saveInsertionPoint(); rewriter.setInsertionPointToStart(outerLoopInfo.body); // Get the values for operand A. - SmallVector AElems; - for (unsigned i = 0; i < sizePerThread[2]; ++i) { - Value idx = builder.add(outerIV, builder.i32_val(i)); - AElems.push_back(builder.extract_element(vecA, idx)); - } + SmallVector AElems = getFragment(vecA, outerIV, K); // Generate the inner loop. - Value innerIV = createIV(zero, rewriter, loc); Value innerUB = outerStep; Value innerStep = one; Value initArg = outerBody->getArgument(outerBody->getNumArguments() - 1); - LoopInfo innerLoopInfo = createEmptyLoop(innerIV, innerUB, innerStep, - initArg, rewriter, loc); + LoopInfo innerLoopInfo = + createEmptyLoop(createIV(zero, rewriter, loc), innerUB, innerStep, + initArg, rewriter, loc); Block *innerBody = innerLoopInfo.body; Block *innerEnd = innerLoopInfo.end; + Value innerIV = innerBody->getArgument(0); rewriter.setInsertionPointToStart(innerLoopInfo.body); // Get the values for operand B. - SmallVector BElems; - for (unsigned j = 0; j < sizePerThread[2]; ++j) { - Value idx = - builder.add(innerIV, builder.i32_val(sizePerThread[2] * j)); - BElems.push_back(builder.extract_element(vecB, idx)); - } + SmallVector BElems = getFragment(vecB, innerIV, K); // Get the value for operand C. Value accIdx = builder.add(builder.mul(innerUB, outerIV), innerIV); @@ -342,8 +352,11 @@ LogicalResult genFMALoop(DotOp op, ValueTableFMA &has, ValueTableFMA &hbs, innerBody->getArgument(innerBody->getNumArguments() - 1); Value acc = builder.extract_element(innerInitArg, accIdx); +#if 1 // Perform the FMAs. - for (unsigned k = 0; k < sizePerThread[2]; ++k) { + acc = multiplier.multiplyVectors(AElems, BElems, acc); +#else + for (unsigned k = 0; k < K; ++k) { TypeSwitch(elemType) .Case([&](auto) { acc = rewriter.create(loc, AElems[k], @@ -353,6 +366,7 @@ LogicalResult genFMALoop(DotOp op, ValueTableFMA &has, ValueTableFMA &hbs, acc = builder.fma(AElems[k], BElems[k], acc); }); } +#endif // Store the result. innerInitArg = builder.insert_element(innerInitArg, acc, accIdx); @@ -369,7 +383,7 @@ LogicalResult genFMALoop(DotOp op, ValueTableFMA &has, ValueTableFMA &hbs, } // Create a loop to copy the result into a struct. - Value ub = builder.i32_val(len); + Value ub = builder.i32_val(acc.size()); auto structPtr = rewriter.create(loc, ptr_ty(ctx), elemType, ub); Value iv = createIV(zero, rewriter, loc); @@ -388,7 +402,7 @@ LogicalResult genFMALoop(DotOp op, ValueTableFMA &has, ValueTableFMA &hbs, rewriter.replaceOp(op, loadVal); llvm::errs() << "at line: " << __LINE__ << "\n"; - llvm::errs() << "Module after:\n"; + llvm::errs() << "Module: "; mod->dumpPretty(); llvm::errs() << "\n"; From a3c362bf62ea22f8d308be136fae89f5d0bf83d9 Mon Sep 17 00:00:00 2001 From: Ettore Tiotto Date: Wed, 17 Sep 2025 19:05:37 +0000 Subject: [PATCH 07/12] WIP: Generate FMA loop - fix bug Signed-off-by: Ettore Tiotto --- lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp index 5c138ec70e..6ba97f7941 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp @@ -386,9 +386,10 @@ LogicalResult genFMALoop(DotOp op, ValueTableFMA &has, ValueTableFMA &hbs, Value ub = builder.i32_val(acc.size()); auto structPtr = rewriter.create(loc, ptr_ty(ctx), elemType, ub); - Value iv = createIV(zero, rewriter, loc); - LoopInfo loopInfo = createEmptyLoop(iv, ub, one, vecD, rewriter, loc); + LoopInfo loopInfo = createEmptyLoop(createIV(zero, rewriter, loc), ub, one, + vecD, rewriter, loc); Block *body = loopInfo.body; + Value iv = body->getArgument(0); auto afterLoop = rewriter.saveInsertionPoint(); rewriter.setInsertionPointToStart(body); Value val = builder.extract_element( From db5b72444c610377c1236923c28c9d61c2f8d127 Mon Sep 17 00:00:00 2001 From: Ettore Tiotto Date: Wed, 17 Sep 2025 21:24:41 +0000 Subject: [PATCH 08/12] WIP: Generate FMA loop - fix bug Signed-off-by: Ettore Tiotto --- .../DotOpToLLVM/FMADotUtility.cpp | 22 +++++-------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp index 6ba97f7941..e76f5b77d3 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp @@ -218,6 +218,7 @@ LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor, << " " << repetitions[2] << "\n"; llvm::errs() << "sizePerThread: " << sizePerThread[0] << " " << sizePerThread[1] << " " << sizePerThread[2] << "\n"; + llvm::errs() << "K = " << K << "\n"; auto mod = op->getParentOfType(); llvm::errs() << "at line: " << __LINE__ << "\n"; @@ -279,7 +280,7 @@ LogicalResult genFMALoop(DotOp op, ValueTableFMA &has, ValueTableFMA &hbs, Location loc = op.getLoc(); auto builder = TritonLLVMOpBuilder(loc, rewriter); - // Copy struct into vector for operand A,B.C. + // Copy struct into vector for operand A,B,C. SmallVector aOpVector, bOpVector; for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep) for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep) @@ -298,8 +299,9 @@ LogicalResult genFMALoop(DotOp op, ValueTableFMA &has, ValueTableFMA &hbs, auto getFragment = [&](Value vec, Value iv, unsigned size) { SmallVector elems; + Value idx = builder.mul(iv, builder.i32_val(size)); for (unsigned i = 0; i < size; ++i) { - Value idx = (i != 0) ? builder.add(iv, builder.i32_val(i)) : iv; + idx = (i != 0) ? builder.add(idx, builder.i32_val(i)) : idx; elems.push_back(builder.extract_element(vec, idx)); } return elems; @@ -316,6 +318,7 @@ LogicalResult genFMALoop(DotOp op, ValueTableFMA &has, ValueTableFMA &hbs, // Generate the outer loop. Value outerUB = builder.i32_val(sizePerThread[1]); Value outerStep = builder.i32_val(sizePerThread[2]); + LoopInfo outerLoopInfo = createEmptyLoop(createIV(zero, rewriter, loc), outerUB, outerStep, {vecC}, rewriter, loc); @@ -331,7 +334,7 @@ LogicalResult genFMALoop(DotOp op, ValueTableFMA &has, ValueTableFMA &hbs, SmallVector AElems = getFragment(vecA, outerIV, K); // Generate the inner loop. - Value innerUB = outerStep; + Value innerUB = builder.i32_val(sizePerThread[2]); Value innerStep = one; Value initArg = outerBody->getArgument(outerBody->getNumArguments() - 1); @@ -352,21 +355,8 @@ LogicalResult genFMALoop(DotOp op, ValueTableFMA &has, ValueTableFMA &hbs, innerBody->getArgument(innerBody->getNumArguments() - 1); Value acc = builder.extract_element(innerInitArg, accIdx); -#if 1 // Perform the FMAs. acc = multiplier.multiplyVectors(AElems, BElems, acc); -#else - for (unsigned k = 0; k < K; ++k) { - TypeSwitch(elemType) - .Case([&](auto) { - acc = rewriter.create(loc, AElems[k], - BElems[k], acc); - }) - .Case([&](auto) { - acc = builder.fma(AElems[k], BElems[k], acc); - }); - } -#endif // Store the result. innerInitArg = builder.insert_element(innerInitArg, acc, accIdx); From 454ee3ca4ec8770e3ee10beb12db1775b267ead9 Mon Sep 17 00:00:00 2001 From: Ettore Tiotto Date: Wed, 17 Sep 2025 21:54:03 +0000 Subject: [PATCH 09/12] WIP: Generate FMA loop - fix bug Signed-off-by: Ettore Tiotto --- .../TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp index e76f5b77d3..80527036fd 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp @@ -152,7 +152,7 @@ namespace mlir::triton::gpu { LogicalResult genFMALoop(DotOp, ValueTableFMA &, ValueTableFMA &, ArrayRef, ArrayRef, - ArrayRef, unsigned, Type, + ArrayRef, const unsigned, Type, ConversionPatternRewriter &, FMAVectorMultiplier &); LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor, @@ -272,8 +272,8 @@ LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor, LogicalResult genFMALoop(DotOp op, ValueTableFMA &has, ValueTableFMA &hbs, ArrayRef acc, ArrayRef sizePerThread, - ArrayRef repetitions, unsigned K, Type dType, - ConversionPatternRewriter &rewriter, + ArrayRef repetitions, const unsigned K, + Type dType, ConversionPatternRewriter &rewriter, FMAVectorMultiplier &multiplier) { ModuleOp mod = op->getParentOfType(); MLIRContext *ctx = rewriter.getContext(); @@ -301,7 +301,8 @@ LogicalResult genFMALoop(DotOp op, ValueTableFMA &has, ValueTableFMA &hbs, SmallVector elems; Value idx = builder.mul(iv, builder.i32_val(size)); for (unsigned i = 0; i < size; ++i) { - idx = (i != 0) ? builder.add(idx, builder.i32_val(i)) : idx; + // TODO: use a increment rather than an add here ? + idx = (i != 0) ? builder.add(idx, builder.i32_val(1)) : idx; elems.push_back(builder.extract_element(vec, idx)); } return elems; From fa1d0536683773918fe9c868e1fa4b1aaed88a45 Mon Sep 17 00:00:00 2001 From: Ettore Tiotto Date: Fri, 19 Sep 2025 19:05:50 +0000 Subject: [PATCH 10/12] WIP: Generate FMA loop Signed-off-by: Ettore Tiotto --- .../DotOpToLLVM/FMADotUtility.cpp | 233 +----------------- third_party/intel/lib/Analysis/DPAS.cpp | 2 + .../lib/TritonIntelGPUToLLVM/CMakeLists.txt | 2 + .../lib/TritonIntelGPUToLLVM/DotOpToLLVM.cpp | 21 +- .../TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp | 6 +- 5 files changed, 21 insertions(+), 243 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp index 80527036fd..d1262be385 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp @@ -1,11 +1,8 @@ #include "triton/Conversion/TritonGPUToLLVM/FMADotUtility.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Value.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" -#include "triton/Tools/Sys/GetEnv.hpp" -#include "llvm/ADT/TypeSwitch.h" +#include using namespace mlir; @@ -77,84 +74,10 @@ ValueTableFMA getValueTableFromStructFMA( return res; } -struct LoopInfo { - Block *header; - Block *body; - Block *end; -}; - -/// Creates an empty loop structure with a header, body, and end block. The -/// loop is initialized with an induction variable and an initial argument. -/// - Parameters: -/// - `iv`: The induction variable (already initialized to the lower bound). -/// - `ub`: The upper bound for the loop. -/// - `step`: The step for the induction variable. -/// - `initArg`: The initial argument passed to the loop. -/// - Returns a `LoopInfo` structure containing the header, body, and end/// -/// blocks. -LoopInfo createEmptyLoop(Value iv, Value ub, Value step, Value initArg, - ConversionPatternRewriter &rewriter, Location loc) { - auto b = TritonLLVMOpBuilder(loc, rewriter); - MLIRContext *ctx = rewriter.getContext(); - - // Create loop blocks. - Block *insertionBlock = rewriter.getInsertionBlock(); - Block *headerBlock = - rewriter.splitBlock(insertionBlock, rewriter.getInsertionPoint()); - Block *bodyBlock = rewriter.splitBlock(headerBlock, headerBlock->begin()); - Block *endBlock = rewriter.splitBlock(bodyBlock, bodyBlock->begin()); - - // Add arguments to blocks. - Type ivTy = iv.getType(); - Type initArgTy = initArg.getType(); - headerBlock->addArguments({ivTy, initArgTy}, {loc, loc}); - bodyBlock->addArguments({ivTy, initArgTy}, {loc, loc}); - endBlock->addArgument(initArgTy, loc); - - // Connect insertion block to header block. - rewriter.setInsertionPointToEnd(insertionBlock); - rewriter.create(loc, headerBlock, ValueRange{iv, initArg}); - - // Build header block. - rewriter.setInsertionPointToStart(headerBlock); - Value cond = b.icmp_slt(headerBlock->getArgument(0), ub); - rewriter.create(loc, cond, bodyBlock, - headerBlock->getArguments(), endBlock, - ValueRange{headerBlock->getArgument(1)}); - - // Build body block. - rewriter.setInsertionPointToStart(bodyBlock); - Value nextIV = b.add(bodyBlock->getArgument(0), step); - rewriter.create(loc, headerBlock, - ValueRange{nextIV, bodyBlock->getArgument(1)}); - - // Set insertion point to end block. - rewriter.setInsertionPointToStart(endBlock); - - return {headerBlock, bodyBlock, endBlock}; -} - -/// Initializes a variable to a given value and returns the loaded value. -/// - Parameters: -/// - `init`: The initial value to assign to the variable. -/// - Returns: The loaded value of the initialized variable. -Value createIV(Value init, ConversionPatternRewriter &rewriter, Location loc) { - auto b = TritonLLVMOpBuilder(loc, rewriter); - auto ptr = rewriter.create(loc, ptr_ty(rewriter.getContext()), - init.getType(), b.i32_val(1)); - rewriter.create(loc, init, ptr); - return rewriter.create(loc, init.getType(), ptr); -} - } // namespace namespace mlir::triton::gpu { -LogicalResult genFMALoop(DotOp, ValueTableFMA &, ValueTableFMA &, - ArrayRef, ArrayRef, - ArrayRef, const unsigned, Type, - ConversionPatternRewriter &, FMAVectorMultiplier &); - LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor, const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, @@ -214,24 +137,6 @@ LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor, SmallVector acc = cc; - llvm::errs() << "repetitions: " << repetitions[0] << " " << repetitions[1] - << " " << repetitions[2] << "\n"; - llvm::errs() << "sizePerThread: " << sizePerThread[0] << " " - << sizePerThread[1] << " " << sizePerThread[2] << "\n"; - llvm::errs() << "K = " << K << "\n"; - - auto mod = op->getParentOfType(); - llvm::errs() << "at line: " << __LINE__ << "\n"; - llvm::errs() << "Module: "; - mod->dumpPretty(); - llvm::errs() << "\n"; - - if (triton::tools::getBoolEnv("TRITON_INTEL_LOWER_DOT_TO_LOOP")) { - Type dType = typeConverter->convertType(dTensorTy); - return genFMALoop(op, has, hbs, acc, sizePerThread, repetitions, K, dType, - rewriter, multiplier); - } - for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep) for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep) for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep) @@ -262,142 +167,6 @@ LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor, auto res = packLLElements(loc, typeConverter, acc, rewriter, dTensorTy); rewriter.replaceOp(op, res); - llvm::errs() << "at line: " << __LINE__ << "\n"; - llvm::errs() << "Module: "; - mod->dumpPretty(); - llvm::errs() << "\n"; - - return success(); -} - -LogicalResult genFMALoop(DotOp op, ValueTableFMA &has, ValueTableFMA &hbs, - ArrayRef acc, ArrayRef sizePerThread, - ArrayRef repetitions, const unsigned K, - Type dType, ConversionPatternRewriter &rewriter, - FMAVectorMultiplier &multiplier) { - ModuleOp mod = op->getParentOfType(); - MLIRContext *ctx = rewriter.getContext(); - Location loc = op.getLoc(); - auto builder = TritonLLVMOpBuilder(loc, rewriter); - - // Copy struct into vector for operand A,B,C. - SmallVector aOpVector, bOpVector; - for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep) - for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep) - for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep) - for (unsigned b = 0; b < sizePerThread[0]; ++b) - for (unsigned m = 0; m < sizePerThread[1]; ++m) - for (unsigned n = 0; n < sizePerThread[2]; ++n) - for (unsigned k = 0; k < K; ++k) { - aOpVector.push_back(has.at({bRep, mRep, b, m, k})); - bOpVector.push_back(hbs.at({bRep, nRep, b, n, k})); - } - - Value vecA = packLLVector(loc, aOpVector, rewriter); - Value vecB = packLLVector(loc, bOpVector, rewriter); - Value vecC = packLLVector(loc, acc, rewriter); - - auto getFragment = [&](Value vec, Value iv, unsigned size) { - SmallVector elems; - Value idx = builder.mul(iv, builder.i32_val(size)); - for (unsigned i = 0; i < size; ++i) { - // TODO: use a increment rather than an add here ? - idx = (i != 0) ? builder.add(idx, builder.i32_val(1)) : idx; - elems.push_back(builder.extract_element(vec, idx)); - } - return elems; - }; - - Value vecD; - Value zero = builder.i32_val(0), one = builder.i32_val(1); - Type elemType = acc.front().getType(); - - for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep) - for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep) - for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep) - for (unsigned b = 0; b < sizePerThread[0]; ++b) { - // Generate the outer loop. - Value outerUB = builder.i32_val(sizePerThread[1]); - Value outerStep = builder.i32_val(sizePerThread[2]); - - LoopInfo outerLoopInfo = - createEmptyLoop(createIV(zero, rewriter, loc), outerUB, outerStep, - {vecC}, rewriter, loc); - Block *outerBody = outerLoopInfo.body; - Block *outerEnd = outerLoopInfo.end; - Value outerIV = outerBody->getArgument(0); - auto outerLatch = cast(outerBody->getTerminator()); - vecD = outerEnd->getArgument(0); - auto afterOuterLoop = rewriter.saveInsertionPoint(); - rewriter.setInsertionPointToStart(outerLoopInfo.body); - - // Get the values for operand A. - SmallVector AElems = getFragment(vecA, outerIV, K); - - // Generate the inner loop. - Value innerUB = builder.i32_val(sizePerThread[2]); - Value innerStep = one; - Value initArg = - outerBody->getArgument(outerBody->getNumArguments() - 1); - LoopInfo innerLoopInfo = - createEmptyLoop(createIV(zero, rewriter, loc), innerUB, innerStep, - initArg, rewriter, loc); - Block *innerBody = innerLoopInfo.body; - Block *innerEnd = innerLoopInfo.end; - Value innerIV = innerBody->getArgument(0); - rewriter.setInsertionPointToStart(innerLoopInfo.body); - - // Get the values for operand B. - SmallVector BElems = getFragment(vecB, innerIV, K); - - // Get the value for operand C. - Value accIdx = builder.add(builder.mul(innerUB, outerIV), innerIV); - Value innerInitArg = - innerBody->getArgument(innerBody->getNumArguments() - 1); - Value acc = builder.extract_element(innerInitArg, accIdx); - - // Perform the FMAs. - acc = multiplier.multiplyVectors(AElems, BElems, acc); - - // Store the result. - innerInitArg = builder.insert_element(innerInitArg, acc, accIdx); - rewriter.restoreInsertionPoint(afterOuterLoop); - - // Pass the result to the next inner loop iteration. - auto innerLatch = cast(innerBody->getTerminator()); - innerLatch->setOperand(innerLatch->getNumOperands() - 1, - innerInitArg); - - // Pass the result of the inner loop to the next outer loop iteration. - Value innerEndArg = innerEnd->getArgument(0); - outerLatch->setOperand(outerLatch->getNumOperands() - 1, innerEndArg); - } - - // Create a loop to copy the result into a struct. - Value ub = builder.i32_val(acc.size()); - auto structPtr = - rewriter.create(loc, ptr_ty(ctx), elemType, ub); - LoopInfo loopInfo = createEmptyLoop(createIV(zero, rewriter, loc), ub, one, - vecD, rewriter, loc); - Block *body = loopInfo.body; - Value iv = body->getArgument(0); - auto afterLoop = rewriter.saveInsertionPoint(); - rewriter.setInsertionPointToStart(body); - Value val = builder.extract_element( - body->getArgument(body->getNumArguments() - 1), iv); - Value ptr = builder.gep(ptr_ty(ctx), val.getType(), structPtr, iv); - rewriter.create(loc, val, ptr); - - // Load the struct and replace the original op. - rewriter.restoreInsertionPoint(afterLoop); - auto loadVal = rewriter.create(loc, dType, structPtr); - rewriter.replaceOp(op, loadVal); - - llvm::errs() << "at line: " << __LINE__ << "\n"; - llvm::errs() << "Module: "; - mod->dumpPretty(); - llvm::errs() << "\n"; - return success(); } diff --git a/third_party/intel/lib/Analysis/DPAS.cpp b/third_party/intel/lib/Analysis/DPAS.cpp index a20cc53777..6bf5a53331 100644 --- a/third_party/intel/lib/Analysis/DPAS.cpp +++ b/third_party/intel/lib/Analysis/DPAS.cpp @@ -41,6 +41,7 @@ DPASAnalysis::DPASAnalysis(Operation *root) { DPASAnalysis::Result DPASAnalysis::canUseDPAS(FunctionOpInterface funcOp) const { + return Result::False; if (funcToDotMap.empty() || dotToDPASEngineMap.empty()) return Result::False; @@ -78,6 +79,7 @@ DPASAnalysis::canUseDPAS(FunctionOpInterface funcOp) const { : Result::False; } + return Result::False; return (threadsPerWarp == minSGSize) ? Result::True : Result::False; } diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt b/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt index c95ad95f43..c8f27bc8f3 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt @@ -5,6 +5,8 @@ add_triton_library(TritonIntelGPUToLLVM ControlFlowOpToLLVM.cpp ConvertLayoutOpToLLVM.cpp DotOpToLLVM/DPAS.cpp + DotOpToLLVM/FMA.cpp + DotOpToLLVM/FMADotUtility.cpp DotOpToLLVM.cpp ElementwiseOpToLLVM.cpp Fp4ToFpOpToLLVM.cpp diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM.cpp index 6283e68a31..d115aba134 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM.cpp @@ -1,21 +1,23 @@ #include "PatternTritonGPUOpToLLVM.h" using namespace mlir; -using namespace mlir::triton; using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::intel::DpasEncodingAttr; -namespace fma_details { -LogicalResult convertDPAS(triton::DotOp op, triton::DotOp::Adaptor adaptor, +namespace mlir::triton::gpu::intel { +LogicalResult convertFMADot(DotOp op, DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter); + +LogicalResult convertDPAS(DotOp op, DotOp::Adaptor adaptor, TritonIntelGPUToLLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter); -} // namespace fma_details +} // namespace mlir::triton::gpu::intel namespace { struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { - using ConvertTritonGPUOpToLLVMPattern< - triton::DotOp>::ConvertTritonGPUOpToLLVMPattern; + using ConvertTritonGPUOpToLLVMPattern::ConvertTritonGPUOpToLLVMPattern; LogicalResult matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, @@ -33,13 +35,14 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { if (!isOuter && isa( cast(D.getType()).getEncoding())) { - return fma_details::convertDPAS(op, adaptor, getTypeConverter(), - rewriter); + return triton::gpu::intel::convertDPAS(op, adaptor, getTypeConverter(), + rewriter); } if (isa( cast(D.getType()).getEncoding())) - return convertFMADot(op, adaptor, getTypeConverter(), rewriter); + return triton::gpu::intel::convertFMADot(op, adaptor, getTypeConverter(), + rewriter); llvm::report_fatal_error( "Unsupported DotOp found when converting TritonGPU to LLVM."); diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp index 71f1a20b9f..188a1fde24 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp @@ -406,7 +406,8 @@ class DotOpDPASConversionHelper { } // namespace -namespace fma_details { +namespace mlir::triton::gpu::intel { + LogicalResult convertDPAS(triton::DotOp op, triton::DotOp::Adaptor adaptor, TritonIntelGPUToLLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter) { @@ -441,4 +442,5 @@ LogicalResult convertDPAS(triton::DotOp op, triton::DotOp::Adaptor adaptor, return helper.convertDot(op, adaptor); } -} // namespace fma_details + +} // namespace mlir::triton::gpu::intel From 3a694722f0edcb09b9a09937fe966d7aa04c0b1f Mon Sep 17 00:00:00 2001 From: Ettore Tiotto Date: Fri, 19 Sep 2025 19:07:29 +0000 Subject: [PATCH 11/12] WIP: Generate FMA loop Signed-off-by: Ettore Tiotto --- .../TritonIntelGPUToLLVM/DotOpToLLVM/FMA.cpp | 65 +++ .../DotOpToLLVM/FMADotUtility.cpp | 438 ++++++++++++++++++ 2 files changed, 503 insertions(+) create mode 100644 third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/FMA.cpp create mode 100644 third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/FMA.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/FMA.cpp new file mode 100644 index 0000000000..6804036023 --- /dev/null +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/FMA.cpp @@ -0,0 +1,65 @@ +#include "triton/Conversion/TritonGPUToLLVM/FMADotUtility.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::triton; +using namespace ::mlir::triton::gpu; + +namespace { +class GenericFMAVectorMultiplier : public FMAVectorMultiplier { + OpBuilder &builder; + Location loc; + +public: + GenericFMAVectorMultiplier(OpBuilder &builder, Location loc) + : builder(builder), loc(loc) {} + + Value multiplyVectors(ArrayRef a, ArrayRef b, + Value c) override { + auto K = a.size(); + assert(b.size() == K); + Value accum = c; + Type tgtTy = accum.getType(); + for (auto it = llvm::zip(a, b).begin(); it != llvm::zip(a, b).end(); ++it) { + const auto &aElem = std::get<0>(*it); + const auto &bElem = std::get<1>(*it); + + assert(aElem.getType() == tgtTy); + assert(bElem.getType() == tgtTy); + + // to avoid: 'llvm.intr.fmuladd' op operand #0 must be floating point LLVM + // type or LLVM dialect-compatible vector of floating point LLVM type, but + // got 'i32' + llvm::TypeSwitch(tgtTy) + .Case([&](auto) { + accum = builder.create(loc, aElem, bElem, accum); + }) + .Case([&](auto) { + accum = builder.create( + loc, builder.create(loc, aElem, bElem), accum); + }); + } + return accum; + } +}; + +} // namespace + +namespace mlir::triton::gpu::intel { + +LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + FMAVectorMultiplier &multiplier); + +LogicalResult convertFMADot(DotOp op, DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + auto *ctx = rewriter.getContext(); + auto loc = op.getLoc(); + GenericFMAVectorMultiplier multiplier(rewriter, loc); + return intel::parametricConvertFMADot(op, adaptor, typeConverter, rewriter, + multiplier); +} +} // namespace mlir::triton::gpu::intel \ No newline at end of file diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp new file mode 100644 index 0000000000..650a9f0e75 --- /dev/null +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp @@ -0,0 +1,438 @@ +#include "triton/Conversion/TritonGPUToLLVM/FMADotUtility.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.h.inc" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Value.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; + +namespace { + +/// OperandValueKey structure represents compile time part +/// of spatial coordinates of a value in a tensor. +/// +/// Every Value spatial coordinates(i.e. [batch;nonK;k]) in tensor can be +/// defined as: +/// +/// batch = (bRepIdx * CTABSize + bIdx) + (laneBCoord + warpBCoord) +/// nonK = (nonKRepIdx * CTANKSize + nonKIdx) + (laneNonKCoord + warpNonKCoord) +/// k = kIdx +/// +/// Where: +/// CTABSize, CTANKSize: constants; +/// laneBCoord, warpBCoord, laneNonKCoord, warpNonKCoord: runtime components; +/// bRepIdx, nonKRepIdx, bIdx, nonKIdx, kIdx: compile time components. +struct OperandValueKey { + unsigned bRepIdx, nonKRepIdx; + unsigned bIdx, nonKIdx, kIdx; + + bool operator==(const OperandValueKey &other) const { + return (bRepIdx == other.bRepIdx && nonKRepIdx == other.nonKRepIdx && + bIdx == other.bIdx && nonKIdx == other.nonKIdx && + kIdx == other.kIdx); + } + + void print() const { + llvm::errs() << "[" << bRepIdx << "," << nonKRepIdx << "," << bIdx << "," + << nonKIdx << "," << kIdx << "]"; + } +}; + +} // namespace + +template <> struct std::hash { + std::size_t operator()(const OperandValueKey &k) const { + return llvm::hash_combine(k.bRepIdx, k.nonKRepIdx, k.bIdx, k.nonKIdx, + k.kIdx); + } +}; + +namespace { + +using ValueTableFMA = std::unordered_map; + +ValueTableFMA getValueTableFromStructFMA( + Value val, ArrayRef perRepShape, ArrayRef repetitions, + unsigned kDim, unsigned nonKDim, ConversionPatternRewriter &rewriter, + Location loc, ArrayRef inRepOrder, ArrayRef repOrder) { + ValueTableFMA res; + auto elems = unpackLLElements(loc, val, rewriter); + assert(perRepShape.size() == 3); + auto numElemsRep = product(perRepShape); + assert(elems.size() == numElemsRep * product(repetitions)); + assert(kDim == 1 || kDim == 2); + assert(nonKDim == 1 || nonKDim == 2); + const unsigned bDim = 0; + + for (unsigned idx = 0; idx < elems.size(); ++idx) { + auto inRepLinearIdx = idx % numElemsRep; + auto repLinearIdx = idx / numElemsRep; + auto inRepSpatialIdx = + mlir::LLVM::delinearize(inRepLinearIdx, perRepShape, inRepOrder); + auto repSpatialIdx = + mlir::LLVM::delinearize(repLinearIdx, repetitions, repOrder); + OperandValueKey key{repSpatialIdx[0], repSpatialIdx[nonKDim], + inRepSpatialIdx[0], inRepSpatialIdx[nonKDim], + inRepSpatialIdx[kDim]}; + + llvm::errs() << "key: "; + key.print(); + llvm::errs() << "idx = " << idx << "\n"; + res[key] = elems[idx]; + } + return res; +} + +struct LoopInfo { + Block *header; + Block *body; + Block *end; +}; + +/// Creates an empty loop structure with a header, body, and end block. The +/// loop is initialized with an induction variable and an initial argument. +/// - Parameters: +/// - `iv`: The induction variable (already initialized to the lower bound). +/// - `ub`: The upper bound for the loop. +/// - `step`: The step for the induction variable. +/// - `initArg`: The initial argument passed to the loop. +/// - Returns a `LoopInfo` structure containing the header, body, and end +// blocks. +LoopInfo createEmptyLoop(Value iv, Value ub, Value step, Value initArg, + ConversionPatternRewriter &rewriter, Location loc) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + MLIRContext *ctx = rewriter.getContext(); + + // Create loop blocks. + Block *insertionBlock = rewriter.getInsertionBlock(); + Block *headerBlock = + rewriter.splitBlock(insertionBlock, rewriter.getInsertionPoint()); + Block *bodyBlock = rewriter.splitBlock(headerBlock, headerBlock->begin()); + Block *endBlock = rewriter.splitBlock(bodyBlock, bodyBlock->begin()); + + // Add arguments to blocks. + Type ivTy = iv.getType(); + Type initArgTy = initArg.getType(); + headerBlock->addArguments({ivTy, initArgTy}, {loc, loc}); + bodyBlock->addArguments({ivTy, initArgTy}, {loc, loc}); + endBlock->addArgument(initArgTy, loc); + + // Connect insertion block to header block. + rewriter.setInsertionPointToEnd(insertionBlock); + rewriter.create(loc, headerBlock, ValueRange{iv, initArg}); + + // Build header block. + rewriter.setInsertionPointToStart(headerBlock); + Value cond = b.icmp_slt(headerBlock->getArgument(0), ub); + rewriter.create(loc, cond, bodyBlock, + headerBlock->getArguments(), endBlock, + ValueRange{headerBlock->getArgument(1)}); + + // Build body block. + rewriter.setInsertionPointToStart(bodyBlock); + Value nextIV = b.add(bodyBlock->getArgument(0), step); + auto latch = rewriter.create( + loc, headerBlock, ValueRange{nextIV, bodyBlock->getArgument(1)}); + + // Look at example in: + // + auto noUnrollAttr = StringAttr::get(ctx, "llvm.loop.unroll.disable"); + auto namedAttr = rewriter.getNamedAttr("llvm.loop", noUnrollAttr); + auto arrayAttr = rewriter.getArrayAttr({noUnrollAttr}); + latch->setAttr("llvm.loop", arrayAttr); + + // Set insertion point to end block. + rewriter.setInsertionPointToStart(endBlock); + + return {headerBlock, bodyBlock, endBlock}; +} + +/// Initializes a variable to a given value and returns the loaded value. +/// - Parameters: +/// - `init`: The initial value to assign to the variable. +/// - Returns: The loaded value of the initialized variable. +Value createIV(Value init, ConversionPatternRewriter &rewriter, Location loc) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto ptr = rewriter.create(loc, ptr_ty(rewriter.getContext()), + init.getType(), b.i32_val(1)); + rewriter.create(loc, init, ptr); + return rewriter.create(loc, init.getType(), ptr); +} + +} // namespace + +namespace mlir::triton::gpu::intel { + +LogicalResult genFMALoop(DotOp, ValueTableFMA &, ValueTableFMA &, + ArrayRef, ArrayRef, + ArrayRef, ArrayRef, + ArrayRef, const unsigned, const unsigned, + Type, ConversionPatternRewriter &, + FMAVectorMultiplier &); + +LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + FMAVectorMultiplier &multiplier) { + auto *ctx = rewriter.getContext(); + auto loc = op.getLoc(); + + auto A = op.getA(); + auto D = op.getResult(); + + auto aTensorTy = cast(A.getType()); + auto dTensorTy = cast(D.getType()); + + SmallVector aShapePerCTA = + expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(aTensorTy))); + auto dShapePerCTA = + expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(dTensorTy))); + + BlockedEncodingAttr dLayout = + cast(dTensorTy.getEncoding()); + // TODO process A and B operand separately + auto inRepOrder = expandMatrixOrderWithBatch(dLayout.getOrder()); + auto repOrder = expandMatrixOrderWithBatch(dLayout.getRepOrder()); + auto cc = unpackLLElements(loc, adaptor.getC(), rewriter); + + Value llA = adaptor.getA(); + Value llB = adaptor.getB(); + + auto sizePerThread = getContigPerThread(dTensorTy); + auto numElemsPerThread = product(sizePerThread); + SmallVector shapePerCTATile; + for (auto [reg, thread, warp] : + llvm::zip(sizePerThread, dLayout.getThreadsPerWarp(), + dLayout.getWarpsPerCTA())) { + shapePerCTATile.push_back(reg * thread * warp); + } + shapePerCTATile = expandMatrixShapeWithBatch(ArrayRef(shapePerCTATile)); + sizePerThread = expandMatrixShapeWithBatch(ArrayRef(sizePerThread)); + + unsigned K = aShapePerCTA[2]; + + unsigned threadTileShape[3]; + unsigned repetitions[3]; + for (int i = 0; i < 3; ++i) { + repetitions[i] = + ceil(dShapePerCTA[i], static_cast(shapePerCTATile[i])); + } + + auto has = getValueTableFromStructFMA( + llA, {sizePerThread[0], sizePerThread[1], K}, + {repetitions[0], repetitions[1], 1}, + /*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, inRepOrder, repOrder); + auto hbs = getValueTableFromStructFMA( + llB, {sizePerThread[0], K, sizePerThread[2]}, + {repetitions[0], 1, repetitions[2]}, + /*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, inRepOrder, repOrder); + + SmallVector acc = cc; + + if (triton::tools::getBoolEnv("TRITON_INTEL_LOWER_DOT_TO_LOOP")) { + Type dType = typeConverter->convertType(dTensorTy); + return genFMALoop(op, has, hbs, acc, sizePerThread, repetitions, inRepOrder, + repOrder, numElemsPerThread, K, dType, rewriter, + multiplier); + } + + for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep) + for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep) + for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep) + for (unsigned b = 0; b < sizePerThread[0]; ++b) + for (unsigned m = 0; m < sizePerThread[1]; ++m) + for (unsigned n = 0; n < sizePerThread[2]; ++n) { + SmallVector multiDimAccumIdx = {b, m, n}; + unsigned linearInRepIdx = + LLVM::linearize(multiDimAccumIdx, sizePerThread, inRepOrder); + SmallVector multiDimRepIdx = {bRep, mRep, nRep}; + unsigned linearRepIdx = + LLVM::linearize(multiDimRepIdx, repetitions, repOrder); + unsigned linearAccumIdx = + linearInRepIdx + linearRepIdx * numElemsPerThread; + + SmallVector aOpVector; + SmallVector bOpVector; + + for (unsigned k = 0; k < K; ++k) { + aOpVector.push_back(has.at({bRep, mRep, b, m, k})); + bOpVector.push_back(hbs.at({bRep, nRep, b, n, k})); + } + + acc[linearAccumIdx] = multiplier.multiplyVectors( + aOpVector, bOpVector, acc[linearAccumIdx]); + } + + auto res = packLLElements(loc, typeConverter, acc, rewriter, dTensorTy); + rewriter.replaceOp(op, res); + + return success(); +} + +LogicalResult genFMALoop(DotOp op, ValueTableFMA &has, ValueTableFMA &hbs, + ArrayRef acc, ArrayRef sizePerThread, + ArrayRef repetitions, + ArrayRef inRepOrder, + ArrayRef repOrder, + const unsigned numElemsPerThread, const unsigned K, + Type dType, ConversionPatternRewriter &rewriter, + FMAVectorMultiplier &multiplier) { + MLIRContext *ctx = rewriter.getContext(); + Location loc = op.getLoc(); + auto builder = TritonLLVMOpBuilder(loc, rewriter); + + // Copy structs into vector for operand A, B, C. + SmallVector aOpVector, bOpVector; + for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep) + for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep) + for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep) + for (unsigned b = 0; b < sizePerThread[0]; ++b) + for (unsigned m = 0; m < sizePerThread[1]; ++m) + for (unsigned n = 0; n < sizePerThread[2]; ++n) + for (unsigned k = 0; k < K; ++k) { + aOpVector.push_back(has.at({bRep, mRep, b, m, k})); + bOpVector.push_back(hbs.at({bRep, nRep, b, n, k})); + } + + Value vecA = packLLVector(loc, aOpVector, rewriter); + Value vecB = packLLVector(loc, bOpVector, rewriter); + Value vecC = packLLVector(loc, acc, rewriter); + + auto getFragment = [&](Value vec, unsigned bRep, unsigned mRep, unsigned nRep, + unsigned b, Value outerIV, Value innerIV, + Value outerUB, Value innerUB, unsigned K) { + SmallVector elems; + // TODO: compute idx by also using bRep, nRep, b. + // mRep * M * N * K + m * N * K + n * K + k + Value idx = builder.mul( + builder.mul(builder.mul(builder.i32_val(mRep), outerUB), innerUB), + builder.i32_val(K)); + idx = builder.add( + idx, builder.mul(builder.mul(outerIV, innerUB), builder.i32_val(K))); + idx = builder.add(idx, builder.mul(innerIV, builder.i32_val(K))); + + for (unsigned k = 0; k < K; ++k) { + idx = (k != 0) ? builder.add(idx, builder.i32_val(1)) : idx; + elems.push_back(builder.extract_element(vec, idx)); + } + return elems; + }; + + auto linearize = [&](ArrayRef multiDim, ArrayRef shape, + ArrayRef order) { + auto builder = TritonLLVMOpBuilder(loc, rewriter); + Value linear = builder.i32_val(0); + for (unsigned dim : llvm::reverse(order)) { + Value mul = builder.mul(linear, builder.i32_val(shape[dim])); + linear = builder.add(mul, multiDim[dim]); + } + return linear; + }; + + Value zero = builder.i32_val(0), one = builder.i32_val(1); + Type elemType = acc.front().getType(); + Value initArg = vecC; + + Value vecD; + for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep) { + for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep) { + for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep) { + for (unsigned b = 0; b < sizePerThread[0]; ++b) { + // Generate the outer loop. + Value outerUB = builder.i32_val(sizePerThread[1]); + LoopInfo outerLoopInfo = + createEmptyLoop(createIV(zero, rewriter, loc), outerUB, one, + {initArg}, rewriter, loc); + Block *outerBody = outerLoopInfo.body; + Block *outerEnd = outerLoopInfo.end; + Value outerIV = outerBody->getArgument(0); + auto outerLatch = cast(outerBody->getTerminator()); + vecD = outerEnd->getArgument(0); + auto afterOuterLoop = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointToStart(outerLoopInfo.body); + + // Generate the inner loop. + Value innerUB = builder.i32_val(sizePerThread[2]); + Value initArg = + outerBody->getArgument(outerBody->getNumArguments() - 1); + LoopInfo innerLoopInfo = + createEmptyLoop(createIV(zero, rewriter, loc), innerUB, one, + initArg, rewriter, loc); + Block *innerBody = innerLoopInfo.body; + Block *innerEnd = innerLoopInfo.end; + Value innerIV = innerBody->getArgument(0); + rewriter.setInsertionPointToStart(innerLoopInfo.body); + + // Get the fragments for operands A and B. + SmallVector AElems = getFragment( + vecA, bRep, mRep, nRep, b, outerIV, innerIV, outerUB, innerUB, K); + SmallVector BElems = getFragment( + vecB, bRep, mRep, nRep, b, outerIV, innerIV, outerUB, innerUB, K); + + // Compute the index into the accumulator. + SmallVector multiDimAccumIdx{builder.i32_val(b), outerIV, + innerIV}; + Value linearInRepIdx = + linearize(multiDimAccumIdx, sizePerThread, inRepOrder); + SmallVector multiDimRepIdx = {bRep, mRep, nRep}; + unsigned linearRepIdx = + LLVM::linearize(multiDimRepIdx, repetitions, repOrder); + Value linearAccumIdx = + builder.add(linearInRepIdx, + builder.i32_val(linearRepIdx * numElemsPerThread)); + Value innerInitArg = + innerBody->getArgument(innerBody->getNumArguments() - 1); + + // Extract the element from accumulator. + Value CElem = builder.extract_element(innerInitArg, linearAccumIdx); + + // Perform the FMAs. + CElem = multiplier.multiplyVectors(AElems, BElems, CElem); + + // Store the result. + innerInitArg = + builder.insert_element(innerInitArg, CElem, linearAccumIdx); + rewriter.restoreInsertionPoint(afterOuterLoop); + + // Pass the result to the next inner loop iteration. + auto innerLatch = cast(innerBody->getTerminator()); + innerLatch->setOperand(innerLatch->getNumOperands() - 1, + innerInitArg); + + // Pass the result of the inner loop to the next outer loop iteration. + Value innerEndArg = innerEnd->getArgument(0); + outerLatch->setOperand(outerLatch->getNumOperands() - 1, innerEndArg); + } + initArg = vecD; + } + } + } + + // Create a loop to copy the final result into a struct. + Value UB = builder.i32_val(acc.size()); + auto structPtr = + rewriter.create(loc, ptr_ty(ctx), elemType, UB); + LoopInfo loopInfo = createEmptyLoop(createIV(zero, rewriter, loc), UB, one, + vecD, rewriter, loc); + Block *body = loopInfo.body; + Value iv = body->getArgument(0); + auto afterLoop = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointToStart(body); + Value val = builder.extract_element( + body->getArgument(body->getNumArguments() - 1), iv); + Value ptr = builder.gep(ptr_ty(ctx), val.getType(), structPtr, iv); + rewriter.create(loc, val, ptr); + + // Load the struct and replace the original op. + rewriter.restoreInsertionPoint(afterLoop); + auto loadVal = rewriter.create(loc, dType, structPtr); + rewriter.replaceOp(op, loadVal); + + return success(); +} + +} // namespace mlir::triton::gpu::intel From 8dae5adb296a9e0f31af746adf862c94af5d8dc5 Mon Sep 17 00:00:00 2001 From: Ettore Tiotto Date: Fri, 19 Sep 2025 19:09:41 +0000 Subject: [PATCH 12/12] WIP: Generate FMA loop Signed-off-by: Ettore Tiotto --- lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp | 3 --- third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/FMA.cpp | 3 ++- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp index d1262be385..fa2c814722 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp @@ -1,8 +1,5 @@ #include "triton/Conversion/TritonGPUToLLVM/FMADotUtility.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/Value.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" -#include using namespace mlir; diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/FMA.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/FMA.cpp index 6804036023..2a59b5d9a1 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/FMA.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/FMA.cpp @@ -62,4 +62,5 @@ LogicalResult convertFMADot(DotOp op, DotOp::Adaptor adaptor, return intel::parametricConvertFMADot(op, adaptor, typeConverter, rewriter, multiplier); } -} // namespace mlir::triton::gpu::intel \ No newline at end of file + +} // namespace mlir::triton::gpu::intel