diff --git a/src/enzyme_ad/jax/Dialect/TritonExt/Ops.td b/src/enzyme_ad/jax/Dialect/TritonExt/Ops.td index 935271ab7..a5f20ce0b 100644 --- a/src/enzyme_ad/jax/Dialect/TritonExt/Ops.td +++ b/src/enzyme_ad/jax/Dialect/TritonExt/Ops.td @@ -16,6 +16,8 @@ def TensorI64 "tensor", "::mlir::TensorType">, BuildableType<"RankedTensorType::get({}, $_builder.getIntegerType(64))">; +def ScratchTensor : RankedTensorOf<[I8]>; + def TritonModuleOp : TritonExtOp<"module", [ IsolatedFromAbove, NoRegionArguments, SymbolTable, Symbol, NoTerminator ]> { @@ -32,6 +34,21 @@ def TritonModuleOp : TritonExtOp<"module", [ // clang-format on } +def ScratchMemoryOp : TritonExtOp<"scratch_memory", [Pure]> { + let summary = "Allocate scratch memory"; + let description = [{ Allocate scratch memory for a kernel. }]; + + let arguments = (ins I64Attr : $alignment); + + let results = (outs ScratchTensor : $result); + + // clang-format off + let assemblyFormat = [{ + attr-dict `:` type($result) + }]; + // clang-format on +} + def TritonCallOp : TritonExtOp<"call", [ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, diff --git a/src/enzyme_ad/jax/Passes/LowerTriton.cpp b/src/enzyme_ad/jax/Passes/LowerTriton.cpp new file mode 100644 index 000000000..94e2f1e9d --- /dev/null +++ b/src/enzyme_ad/jax/Passes/LowerTriton.cpp @@ -0,0 +1,98 @@ +#include "src/enzyme_ad/jax/Passes/Passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "src/enzyme_ad/jax/Dialect/Dialect.h" +#include "src/enzyme_ad/jax/Dialect/Ops.h" +#include "src/enzyme_ad/jax/Dialect/TritonExt/Dialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "src/enzyme_ad/jax/Utils.h" + +#include "llvm/ADT/SmallVector.h" + +#define DEBUG_TYPE "lower-triton" + +namespace mlir { +namespace enzyme { +#define GEN_PASS_DEF_LOWERTRITONPASS +#include "src/enzyme_ad/jax/Passes/Passes.h.inc" +} // namespace enzyme +} // namespace mlir + +using namespace mlir; +using namespace mlir::enzyme; +using namespace mlir::enzymexla; +using namespace mlir::enzymexla::triton_ext; + +LogicalResult lowerTritonKernelToKernelCall(ModuleOp mod, + triton_ext::TritonCallOp op) { + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(mod); + auto funcOp = symbolTable.lookupNearestSymbolFrom(op, op.getFnAttr()); + if (!funcOp) { + op->emitError("Failed to find function '") << op.getFn() << "' in module"; + return failure(); + } + + auto wrappedMod = funcOp->getParentOfType(); + if (!wrappedMod) { + op->emitError("Failed to find parent built-in module."); + return failure(); + } + + if (!wrappedMod->hasAttr("ttg.shared")) { + op->emitError("No ttg.shared attribute found. Triton Passes must be run " + "before invoking lower-triton pass."); + return failure(); + } + + auto ttModOP = wrappedMod->getParentOfType(); + if (!ttModOP) { + op->emitError("No `triton_ext.module` found!"); + return failure(); + } + ttModOP.setVisibility(SymbolTable::Visibility::Private); + + OpBuilder builder(op); + + auto sharedMemSizeAttr = wrappedMod->getAttrOfType("ttg.shared"); + auto sharedMemSize = sharedMemSizeAttr.getValue().getZExtValue(); + auto shmemOpType = op.getGridx().getType(); + auto shmemOp = stablehlo::ConstantOp::create( + builder, op.getLoc(), shmemOpType, + cast(makeAttr(shmemOpType, sharedMemSize))); + + auto kernelCallOp = enzymexla::KernelCallOp::create( + builder, op.getLoc(), op.getResultTypes(), op.getFn(), op.getGridx(), + op.getGridy(), op.getGridz(), op.getBlockx(), op.getBlocky(), + op.getBlockz(), shmemOp, op.getClusterx(), op.getClustery(), + op.getClusterz(), op.getInputs(), op.getBackendConfigAttr(), + op.getOperandLayoutsAttr(), op.getResultLayoutsAttr(), + op.getArgAttrsAttr(), op.getResAttrsAttr(), + op.getOutputOperandAliasesAttr(), op.getXlaSideEffectFreeAttr()); + op.replaceAllUsesWith(kernelCallOp); + op.erase(); + return success(); +} + +struct LowerTritonPass + : public mlir::enzyme::impl::LowerTritonPassBase { + using Base::Base; + + void runOnOperation() override { + auto modOp = getOperation(); + + modOp->walk([&](triton_ext::TritonCallOp op) { + if (failed(lowerTritonKernelToKernelCall(modOp, op))) { + signalPassFailure(); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + } +}; diff --git a/src/enzyme_ad/jax/Passes/LowerTritonExtensionOps.cpp b/src/enzyme_ad/jax/Passes/LowerTritonExtensionOps.cpp new file mode 100644 index 000000000..7882e12d6 --- /dev/null +++ b/src/enzyme_ad/jax/Passes/LowerTritonExtensionOps.cpp @@ -0,0 +1,151 @@ +#include "src/enzyme_ad/jax/Passes/Passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "src/enzyme_ad/jax/Dialect/Dialect.h" +#include "src/enzyme_ad/jax/Dialect/Ops.h" +#include "src/enzyme_ad/jax/Dialect/TritonExt/Dialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "src/enzyme_ad/jax/Utils.h" + +#include "llvm/ADT/SmallVector.h" + +#define DEBUG_TYPE "lower-triton-extension-ops" + +namespace mlir { +namespace enzyme { +#define GEN_PASS_DEF_LOWERTRITONEXTENSIONOPSPASS +#include "src/enzyme_ad/jax/Passes/Passes.h.inc" +} // namespace enzyme +} // namespace mlir + +using namespace mlir; +using namespace mlir::enzyme; +using namespace mlir::enzymexla; +using namespace mlir::enzymexla::triton_ext; + +struct JITCallScratchMemoryLowering + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(enzymexla::JITCallOp op, + PatternRewriter &rewriter) const override { + auto inputs = op.getInputs(); + + BitVector rewriteScratchMemoryIdxs(inputs.size(), false); + SmallVector newInputs; + bool hasScratchMemory = false; + for (size_t i = 0; i < inputs.size(); i++) { + if (auto scratchMemoryOp = + inputs[i].getDefiningOp()) { + hasScratchMemory = true; + rewriteScratchMemoryIdxs.set(i); + continue; + } + newInputs.push_back(inputs[i]); + } + + if (!hasScratchMemory) + return failure(); // nothing to do + + // hoist the scratch memory allocation and use gpu.alloc to allocate this + // memory in the jit call function + auto modOp = op->getParentOfType(); + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(modOp); + auto funcOp = symbolTable.lookupNearestSymbolFrom(op, op.getFnAttr()); + if (!funcOp) { + op->emitError("Failed to find function '") << op.getFn() << "' in module"; + return failure(); + } + + auto funcOpInterface = dyn_cast(funcOp); + + auto &fnBody = funcOp->getRegion(0).front(); + + for (unsigned idx : rewriteScratchMemoryIdxs.set_bits()) { + rewriter.setInsertionPoint(&fnBody, fnBody.begin()); + auto scratchMemoryOp = + inputs[idx].getDefiningOp(); + auto outTy = + cast(scratchMemoryOp.getResult().getType()); + assert(outTy.getRank() == 1); + + auto outMemrefType = MemRefType::get( + outTy.getShape(), outTy.getElementType(), MemRefLayoutAttrInterface{}, + rewriter.getI64IntegerAttr( + cast(fnBody.getArgument(idx).getType()) + .getAddressSpace())); + + auto allocOp = + memref::AllocOp::create(rewriter, op.getLoc(), outMemrefType, + scratchMemoryOp.getAlignmentAttr()); + auto ptrOp = enzymexla::Memref2PointerOp::create( + rewriter, op.getLoc(), + LLVM::LLVMPointerType::get(rewriter.getContext(), + outMemrefType.getMemorySpaceAsInt()), + allocOp.getResult()); + rewriter.replaceAllUsesWith(fnBody.getArgument(idx), ptrOp.getResult()); + + // clang-format off + // FIXME: This is producing + // error: 'llvm.call' op operand type mismatch for operand 0: '!llvm.ptr<1>' != '!llvm.ptr' + // see current operation: "llvm.call"(%61, %60) <{CConv = #llvm.cconv, TailCallKind = #llvm.tailcallkind, callee = @mgpuMemFree, fastmathFlags = #llvm.fastmath, op_bundle_sizes = array, operandSegmentSizes = array}> : (!llvm.ptr<1>, !llvm.ptr) -> () + // SmallVector deps; + // Operation *lastUser = ptrOp; + // for (auto u : ptrOp->getUsers()) { + // if (auto gpuLaunchOp = dyn_cast(u)) { + // deps.push_back(gpuLaunchOp.getAsyncToken()); + // } + + // if (lastUser->isBeforeInBlock(u)) { + // lastUser = u; + // } + // } + + // rewriter.setInsertionPointAfter(lastUser); + // gpu::DeallocOp::create(rewriter, op.getLoc(), + // gpu::AsyncTokenType::get(rewriter.getContext()), + // ValueRange(deps), allocOp.getResult()); + // clang-format on + } + + funcOpInterface.eraseArguments(rewriteScratchMemoryIdxs); + + // TODO: to be safe we should rework the other attributes if they are being + // removed... + rewriter.setInsertionPoint(op); + auto newJitCallOp = enzymexla::JITCallOp::create( + rewriter, op.getLoc(), op.getResultTypes(), op.getFn(), newInputs, + op.getBackendConfigAttr(), op.getOperandLayoutsAttr(), + op.getResultLayoutsAttr(), op.getArgAttrsAttr(), op.getResAttrsAttr(), + op.getOutputOperandAliasesAttr(), op.getXlaSideEffectFreeAttr()); + rewriter.replaceOp(op, newJitCallOp); + return success(); + } +}; + +struct LowerTritonExtensionOpsPass + : public mlir::enzyme::impl::LowerTritonExtensionOpsPassBase< + LowerTritonExtensionOpsPass> { + using Base::Base; + + void runOnOperation() override { + auto context = getOperation()->getContext(); + + RewritePatternSet patterns(context); + patterns.add(context); + + GreedyRewriteConfig config; + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config))) { + signalPassFailure(); + } + } +}; diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td index 42622da99..a8abeecd3 100644 --- a/src/enzyme_ad/jax/Passes/Passes.td +++ b/src/enzyme_ad/jax/Passes/Passes.td @@ -1039,4 +1039,37 @@ def ConvertTritonToTritonGPUPreservingModuleAttributesPass : Pass< >]; } +def LowerTritonPass : Pass<"lower-triton", "mlir::ModuleOp"> { + let summary = "Lower Triton to kernel call"; + let dependentDialects = [ + "triton::TritonDialect", + "gpu::GPUDialect", + "enzymexla::EnzymeXLADialect", + "func::FuncDialect", + "enzymexla::triton_ext::TritonExtDialect", + "stablehlo::StablehloDialect", + ]; +} + +def TritonAugmentFunctionWithExtraArgumentsPass : Pass< + "triton-augment-function-with-extra-arguments", "mlir::ModuleOp"> { + let dependentDialects = [ + "triton::TritonDialect", + "func::FuncDialect", + "enzymexla::triton_ext::TritonExtDialect", + ]; +} + +def LowerTritonExtensionOpsPass : Pass<"lower-triton-extension-ops"> { + let dependentDialects = [ + "triton::TritonDialect", + "func::FuncDialect", + "LLVM::LLVMDialect", + "memref::MemRefDialect", + "enzymexla::EnzymeXLADialect", + "enzymexla::triton_ext::TritonExtDialect", + "gpu::GPUDialect", + ]; +} + #endif diff --git a/src/enzyme_ad/jax/Passes/TritonAugmentFunctionWithExtraArguments.cpp b/src/enzyme_ad/jax/Passes/TritonAugmentFunctionWithExtraArguments.cpp new file mode 100644 index 000000000..0fb2691d5 --- /dev/null +++ b/src/enzyme_ad/jax/Passes/TritonAugmentFunctionWithExtraArguments.cpp @@ -0,0 +1,145 @@ +#include "src/enzyme_ad/jax/Passes/Passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "src/enzyme_ad/jax/Dialect/TritonExt/Dialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" + +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "src/enzyme_ad/jax/Utils.h" + +#include "llvm/ADT/SmallVector.h" + +#define DEBUG_TYPE "triton-augment-function-with-extra-arguments" + +namespace mlir { +namespace enzyme { +#define GEN_PASS_DEF_TRITONAUGMENTFUNCTIONWITHEXTRAARGUMENTSPASS +#include "src/enzyme_ad/jax/Passes/Passes.h.inc" +} // namespace enzyme +} // namespace mlir + +using namespace mlir; +using namespace mlir::enzyme; +using namespace mlir::enzymexla; +using namespace mlir::enzymexla::triton_ext; + +// See for description on the extra arguments +// https://github.com/triton-lang/triton/blob/6ac622c57152ce88edd058f11997b5c5e18d096b/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp#L12-L25 + +LogicalResult +augmentTritonCallOpWithExtraArguments(ModuleOp mod, + triton_ext::TritonCallOp op) { + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(mod); + auto funcOp = symbolTable.lookupNearestSymbolFrom( + mod, op.getFnAttr()); + if (!funcOp) { + op->emitError("Failed to find function '") << op.getFn() << "' in module"; + return failure(); + } + + auto fnKind = funcOp->getName().getStringRef(); + if (fnKind != "llvm.func") { + op->emitError("augmentTritonCallOpWithExtraArguments: expected '") + << op.getFn() << "' to be a llvm.func, got: " << fnKind << ". This " + << "means that the pass is being called before tt.func is being " + "lowered to llvm.func"; + return failure(); + } + + if (funcOp.getNumArguments() == op.getInputs().size()) { + return success(); // already augmented + } + + // See NOTE: [Additional Function Arguments] in triton-lang/triton + if (!mlir::triton::isKernel(funcOp)) { + op->emitError("not a kernel function"); + return failure(); + } + + bool hasProfileScratchMemory = + funcOp.getNumArguments() == + op.getInputs().size() + 2; // to support compatibility with old kernels + + if (funcOp.getNumArguments() != + op.getInputs().size() + 1 + hasProfileScratchMemory) { + op->emitError("Expected ") + << (funcOp.getNumArguments() - 1 - hasProfileScratchMemory) + << " arguments, got " << op.getInputs().size(); + return failure(); + } + + auto newInputs = llvm::to_vector(op.getInputs()); + + // global scratch memory + uint64_t gsmNBytes = 0, gsmAlign = 0; + if (auto gsm = funcOp->getAttrOfType( + "ttg.global_scratch_memory_size")) { + gsmNBytes = gsm.getValue().getZExtValue(); + } + if (auto smalign = funcOp->getAttrOfType( + "ttg.global_scratch_memory_alignment")) { + gsmAlign = smalign.getValue().getZExtValue(); + } + + OpBuilder builder(op); + + auto gsmTy = RankedTensorType::get({static_cast(gsmNBytes)}, + builder.getIntegerType(8)); + auto gsm = triton_ext::ScratchMemoryOp::create( + builder, op.getLoc(), gsmTy, builder.getI64IntegerAttr(gsmAlign)); + newInputs.push_back(gsm); + + // profile scratch memory + if (hasProfileScratchMemory) { + uint64_t psmNBytes = 0, psmAlign = 1; + if (auto psm = funcOp->getAttrOfType( + "ttg.profile_scratch_memory_size")) { + psmNBytes = psm.getValue().getZExtValue(); + } + if (auto psmalign = funcOp->getAttrOfType( + "ttg.profile_scratch_memory_alignment")) { + psmAlign = psmalign.getValue().getZExtValue(); + } + + auto psmTy = RankedTensorType::get({static_cast(psmNBytes)}, + builder.getIntegerType(8)); + auto psm = triton_ext::ScratchMemoryOp::create( + builder, op.getLoc(), psmTy, builder.getI64IntegerAttr(psmAlign)); + newInputs.push_back(psm); + } + + auto newCallOp = triton_ext::TritonCallOp::create( + builder, op.getLoc(), op.getResultTypes(), op.getFn(), op.getGridx(), + op.getGridy(), op.getGridz(), op.getBlockx(), op.getBlocky(), + op.getBlockz(), op.getClusterx(), op.getClustery(), op.getClusterz(), + newInputs, op.getBackendConfigAttr(), op.getOperandLayoutsAttr(), + /*resultLayouts*/ nullptr, op.getArgAttrsAttr(), op.getResAttrsAttr(), + op.getOutputOperandAliasesAttr(), op.getXlaSideEffectFreeAttr()); + op.replaceAllUsesWith(newCallOp); + op.erase(); + return success(); +} + +struct TritonAugmentFunctionWithExtraArgumentsPass + : public mlir::enzyme::impl:: + TritonAugmentFunctionWithExtraArgumentsPassBase< + TritonAugmentFunctionWithExtraArgumentsPass> { + using Base::Base; + + void runOnOperation() override { + auto modOp = getOperation(); + + modOp->walk([&](triton_ext::TritonCallOp op) -> WalkResult { + if (failed(augmentTritonCallOpWithExtraArguments(modOp, op))) { + signalPassFailure(); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + } +};