From 7ff82434a9a9158e9382161c9a7c0284beadde1c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 10 Nov 2025 23:47:48 -0600 Subject: [PATCH 1/5] feat: augment tt_ext.call with extra arguments --- src/enzyme_ad/jax/Dialect/TritonExt/Ops.td | 17 ++ src/enzyme_ad/jax/Passes/Passes.td | 9 + ...ritonAugmentFunctionWithExtraArguments.cpp | 154 ++++++++++++++++++ 3 files changed, 180 insertions(+) create mode 100644 src/enzyme_ad/jax/Passes/TritonAugmentFunctionWithExtraArguments.cpp diff --git a/src/enzyme_ad/jax/Dialect/TritonExt/Ops.td b/src/enzyme_ad/jax/Dialect/TritonExt/Ops.td index 935271ab7d..f5db24c593 100644 --- a/src/enzyme_ad/jax/Dialect/TritonExt/Ops.td +++ b/src/enzyme_ad/jax/Dialect/TritonExt/Ops.td @@ -16,6 +16,8 @@ def TensorI64 "tensor", "::mlir::TensorType">, BuildableType<"RankedTensorType::get({}, $_builder.getIntegerType(64))">; +def ScratchTensor : RankedTensorOf<[I8]>; + def TritonModuleOp : TritonExtOp<"module", [ IsolatedFromAbove, NoRegionArguments, SymbolTable, Symbol, NoTerminator ]> { @@ -32,6 +34,21 @@ def TritonModuleOp : TritonExtOp<"module", [ // clang-format on } +def ScratchMemoryOp : TritonExtOp<"scratch_memory", [ConstantLike, Pure]> { + let summary = "Allocate scratch memory"; + let description = [{ Allocate scratch memory for a kernel. }]; + + let arguments = (ins I64Attr : $alignment); + + let results = (outs ScratchTensor : $result); + + // clang-format off + let assemblyFormat = [{ + attr-dict `:` type($result) + }]; + // clang-format on +} + def TritonCallOp : TritonExtOp<"call", [ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td index 42622da991..3a2303913b 100644 --- a/src/enzyme_ad/jax/Passes/Passes.td +++ b/src/enzyme_ad/jax/Passes/Passes.td @@ -1039,4 +1039,13 @@ def ConvertTritonToTritonGPUPreservingModuleAttributesPass : Pass< >]; } +def TritonAugmentFunctionWithExtraArgumentsPass : Pass< + "triton-augment-function-with-extra-arguments", "mlir::ModuleOp"> { + let dependentDialects = [ + "triton::TritonDialect", + "func::FuncDialect", + "enzymexla::triton_ext::TritonExtDialect", + ]; +} + #endif diff --git a/src/enzyme_ad/jax/Passes/TritonAugmentFunctionWithExtraArguments.cpp b/src/enzyme_ad/jax/Passes/TritonAugmentFunctionWithExtraArguments.cpp new file mode 100644 index 0000000000..69415ecef8 --- /dev/null +++ b/src/enzyme_ad/jax/Passes/TritonAugmentFunctionWithExtraArguments.cpp @@ -0,0 +1,154 @@ +#include "src/enzyme_ad/jax/Passes/Passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "src/enzyme_ad/jax/Dialect/TritonExt/Dialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" + +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "src/enzyme_ad/jax/Utils.h" + +#include "llvm/ADT/SmallVector.h" + +#define DEBUG_TYPE "triton-augment-function-with-extra-arguments" + +namespace mlir { +namespace enzyme { +#define GEN_PASS_DEF_TRITONAUGMENTFUNCTIONWITHEXTRAARGUMENTSPASS +#include "src/enzyme_ad/jax/Passes/Passes.h.inc" +} // namespace enzyme +} // namespace mlir + +using namespace mlir; +using namespace mlir::enzyme; +using namespace mlir::enzymexla; +using namespace mlir::enzymexla::triton_ext; + +// See for description on the extra arguments +// https://github.com/triton-lang/triton/blob/6ac622c57152ce88edd058f11997b5c5e18d096b/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp#L12-L25 + +LogicalResult augmentTritonCallOpWithExtraArguments(ModuleOp mod, + triton_ext::TritonCallOp op, + OpBuilder &builder) { + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(mod); + auto funcOp = symbolTable.lookupNearestSymbolFrom( + mod, op.getFnAttr()); + if (!funcOp) { + op->emitError("Failed to find function '") << op.getFn() << "' in module"; + return failure(); + } + + auto fnKind = funcOp->getName().getStringRef(); + if (fnKind != "llvm.func") { + op->emitError("augmentTritonCallOpWithExtraArguments: expected '") + << op.getFn() << "' to be a llvm.func, got: " << fnKind << ". This " + << "means that the pass is being called before tt.func is being " + "lowered to llvm.func"; + return failure(); + } + + if (funcOp.getNumArguments() == op.getInputs().size()) { + return success(); // already augmented + } + + // See NOTE: [Additional Function Arguments] in triton-lang/triton + if (!mlir::triton::isKernel(funcOp)) { + op->emitError("not a kernel function"); + return failure(); + } + + bool hasProfileScratchMemory = + funcOp.getNumArguments() == + op.getInputs().size() + 2; // to support compatibility with old kernels + + if (funcOp.getNumArguments() != + op.getInputs().size() + 1 + hasProfileScratchMemory) { + op->emitError("Expected ") + << (funcOp.getNumArguments() - 1 - hasProfileScratchMemory) + << " arguments, got " << op.getInputs().size(); + return failure(); + } + + auto newInputs = llvm::to_vector(op.getInputs()); + + // global scratch memory + uint64_t gsmNBytes = 0; + uint64_t gsmAlign = 0; + if (auto gsm = funcOp->getAttrOfType( + "ttg.global_scratch_memory_size")) { + gsmNBytes = gsm.getValue().getZExtValue(); + } + if (auto smalign = funcOp->getAttrOfType( + "ttg.global_scratch_memory_alignment")) { + gsmAlign = smalign.getValue().getZExtValue(); + } + + builder.setInsertionPoint(op); + + auto gsmTy = RankedTensorType::get({static_cast(gsmNBytes)}, + builder.getIntegerType(8)); + auto gsm = triton_ext::ScratchMemoryOp::create( + builder, op.getLoc(), gsmTy, builder.getI64IntegerAttr(gsmAlign)); + newInputs.push_back(gsm); + + // profile scratch memory + if (hasProfileScratchMemory) { + uint64_t psmNBytes = 0; + uint64_t psmAlign = 0; + if (auto psm = funcOp->getAttrOfType( + "ttg.profile_scratch_memory_size")) { + psmNBytes = psm.getValue().getZExtValue(); + } + if (auto psmalign = funcOp->getAttrOfType( + "ttg.profile_scratch_memory_alignment")) { + psmAlign = psmalign.getValue().getZExtValue(); + } + + auto psmTy = RankedTensorType::get({static_cast(psmNBytes)}, + builder.getIntegerType(8)); + auto psm = triton_ext::ScratchMemoryOp::create( + builder, op.getLoc(), psmTy, builder.getI64IntegerAttr(psmAlign)); + newInputs.push_back(psm); + } + + auto newCallOp = triton_ext::TritonCallOp::create( + builder, op.getLoc(), op.getResultTypes(), op.getFn(), op.getGridx(), + op.getGridy(), op.getGridz(), op.getBlockx(), op.getBlocky(), + op.getBlockz(), op.getClusterx(), op.getClustery(), op.getClusterz(), + newInputs, op.getBackendConfigAttr(), op.getOperandLayoutsAttr(), + /*resultLayouts*/ nullptr, op.getArgAttrsAttr(), op.getResAttrsAttr(), + op.getOutputOperandAliasesAttr(), op.getXlaSideEffectFreeAttr()); + op.replaceAllUsesWith(newCallOp); + op.erase(); + return success(); +} + +struct TritonAugmentFunctionWithExtraArgumentsPass + : public mlir::enzyme::impl:: + TritonAugmentFunctionWithExtraArgumentsPassBase< + TritonAugmentFunctionWithExtraArgumentsPass> { + using Base::Base; + + void runOnOperation() override { + auto modOp = getOperation(); + + OpBuilder builder(modOp); + + bool anyFailed = false; + modOp->walk([&](triton_ext::TritonCallOp op) -> WalkResult { + if (failed(augmentTritonCallOpWithExtraArguments(modOp, op, builder))) { + anyFailed = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + + if (anyFailed) { + signalPassFailure(); + } + } +}; From 80964d80524e7b7825a07d4e574a1a4df39282e8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 11 Nov 2025 00:24:46 -0600 Subject: [PATCH 2/5] feat: lower triton call to kernel call --- src/enzyme_ad/jax/Dialect/TritonExt/Ops.td | 2 +- src/enzyme_ad/jax/Passes/LowerTriton.cpp | 98 +++++++++++++++++++ src/enzyme_ad/jax/Passes/Passes.td | 12 +++ ...ritonAugmentFunctionWithExtraArguments.cpp | 25 ++--- 4 files changed, 119 insertions(+), 18 deletions(-) create mode 100644 src/enzyme_ad/jax/Passes/LowerTriton.cpp diff --git a/src/enzyme_ad/jax/Dialect/TritonExt/Ops.td b/src/enzyme_ad/jax/Dialect/TritonExt/Ops.td index f5db24c593..a5f20ce0b0 100644 --- a/src/enzyme_ad/jax/Dialect/TritonExt/Ops.td +++ b/src/enzyme_ad/jax/Dialect/TritonExt/Ops.td @@ -34,7 +34,7 @@ def TritonModuleOp : TritonExtOp<"module", [ // clang-format on } -def ScratchMemoryOp : TritonExtOp<"scratch_memory", [ConstantLike, Pure]> { +def ScratchMemoryOp : TritonExtOp<"scratch_memory", [Pure]> { let summary = "Allocate scratch memory"; let description = [{ Allocate scratch memory for a kernel. }]; diff --git a/src/enzyme_ad/jax/Passes/LowerTriton.cpp b/src/enzyme_ad/jax/Passes/LowerTriton.cpp new file mode 100644 index 0000000000..94e2f1e9d1 --- /dev/null +++ b/src/enzyme_ad/jax/Passes/LowerTriton.cpp @@ -0,0 +1,98 @@ +#include "src/enzyme_ad/jax/Passes/Passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "src/enzyme_ad/jax/Dialect/Dialect.h" +#include "src/enzyme_ad/jax/Dialect/Ops.h" +#include "src/enzyme_ad/jax/Dialect/TritonExt/Dialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "src/enzyme_ad/jax/Utils.h" + +#include "llvm/ADT/SmallVector.h" + +#define DEBUG_TYPE "lower-triton" + +namespace mlir { +namespace enzyme { +#define GEN_PASS_DEF_LOWERTRITONPASS +#include "src/enzyme_ad/jax/Passes/Passes.h.inc" +} // namespace enzyme +} // namespace mlir + +using namespace mlir; +using namespace mlir::enzyme; +using namespace mlir::enzymexla; +using namespace mlir::enzymexla::triton_ext; + +LogicalResult lowerTritonKernelToKernelCall(ModuleOp mod, + triton_ext::TritonCallOp op) { + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(mod); + auto funcOp = symbolTable.lookupNearestSymbolFrom(op, op.getFnAttr()); + if (!funcOp) { + op->emitError("Failed to find function '") << op.getFn() << "' in module"; + return failure(); + } + + auto wrappedMod = funcOp->getParentOfType(); + if (!wrappedMod) { + op->emitError("Failed to find parent built-in module."); + return failure(); + } + + if (!wrappedMod->hasAttr("ttg.shared")) { + op->emitError("No ttg.shared attribute found. Triton Passes must be run " + "before invoking lower-triton pass."); + return failure(); + } + + auto ttModOP = wrappedMod->getParentOfType(); + if (!ttModOP) { + op->emitError("No `triton_ext.module` found!"); + return failure(); + } + ttModOP.setVisibility(SymbolTable::Visibility::Private); + + OpBuilder builder(op); + + auto sharedMemSizeAttr = wrappedMod->getAttrOfType("ttg.shared"); + auto sharedMemSize = sharedMemSizeAttr.getValue().getZExtValue(); + auto shmemOpType = op.getGridx().getType(); + auto shmemOp = stablehlo::ConstantOp::create( + builder, op.getLoc(), shmemOpType, + cast(makeAttr(shmemOpType, sharedMemSize))); + + auto kernelCallOp = enzymexla::KernelCallOp::create( + builder, op.getLoc(), op.getResultTypes(), op.getFn(), op.getGridx(), + op.getGridy(), op.getGridz(), op.getBlockx(), op.getBlocky(), + op.getBlockz(), shmemOp, op.getClusterx(), op.getClustery(), + op.getClusterz(), op.getInputs(), op.getBackendConfigAttr(), + op.getOperandLayoutsAttr(), op.getResultLayoutsAttr(), + op.getArgAttrsAttr(), op.getResAttrsAttr(), + op.getOutputOperandAliasesAttr(), op.getXlaSideEffectFreeAttr()); + op.replaceAllUsesWith(kernelCallOp); + op.erase(); + return success(); +} + +struct LowerTritonPass + : public mlir::enzyme::impl::LowerTritonPassBase { + using Base::Base; + + void runOnOperation() override { + auto modOp = getOperation(); + + modOp->walk([&](triton_ext::TritonCallOp op) { + if (failed(lowerTritonKernelToKernelCall(modOp, op))) { + signalPassFailure(); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + } +}; diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td index 3a2303913b..4dbedb834d 100644 --- a/src/enzyme_ad/jax/Passes/Passes.td +++ b/src/enzyme_ad/jax/Passes/Passes.td @@ -1039,6 +1039,18 @@ def ConvertTritonToTritonGPUPreservingModuleAttributesPass : Pass< >]; } +def LowerTritonPass : Pass<"lower-triton", "mlir::ModuleOp"> { + let summary = "Lower Triton to kernel call"; + let dependentDialects = [ + "triton::TritonDialect", + "gpu::GPUDialect", + "enzymexla::EnzymeXLADialect", + "func::FuncDialect", + "enzymexla::triton_ext::TritonExtDialect", + "stablehlo::StablehloDialect", + ]; +} + def TritonAugmentFunctionWithExtraArgumentsPass : Pass< "triton-augment-function-with-extra-arguments", "mlir::ModuleOp"> { let dependentDialects = [ diff --git a/src/enzyme_ad/jax/Passes/TritonAugmentFunctionWithExtraArguments.cpp b/src/enzyme_ad/jax/Passes/TritonAugmentFunctionWithExtraArguments.cpp index 69415ecef8..0fb2691d5a 100644 --- a/src/enzyme_ad/jax/Passes/TritonAugmentFunctionWithExtraArguments.cpp +++ b/src/enzyme_ad/jax/Passes/TritonAugmentFunctionWithExtraArguments.cpp @@ -30,9 +30,9 @@ using namespace mlir::enzymexla::triton_ext; // See for description on the extra arguments // https://github.com/triton-lang/triton/blob/6ac622c57152ce88edd058f11997b5c5e18d096b/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp#L12-L25 -LogicalResult augmentTritonCallOpWithExtraArguments(ModuleOp mod, - triton_ext::TritonCallOp op, - OpBuilder &builder) { +LogicalResult +augmentTritonCallOpWithExtraArguments(ModuleOp mod, + triton_ext::TritonCallOp op) { SymbolTableCollection symbolTable; symbolTable.getSymbolTable(mod); auto funcOp = symbolTable.lookupNearestSymbolFrom( @@ -76,8 +76,7 @@ LogicalResult augmentTritonCallOpWithExtraArguments(ModuleOp mod, auto newInputs = llvm::to_vector(op.getInputs()); // global scratch memory - uint64_t gsmNBytes = 0; - uint64_t gsmAlign = 0; + uint64_t gsmNBytes = 0, gsmAlign = 0; if (auto gsm = funcOp->getAttrOfType( "ttg.global_scratch_memory_size")) { gsmNBytes = gsm.getValue().getZExtValue(); @@ -87,7 +86,7 @@ LogicalResult augmentTritonCallOpWithExtraArguments(ModuleOp mod, gsmAlign = smalign.getValue().getZExtValue(); } - builder.setInsertionPoint(op); + OpBuilder builder(op); auto gsmTy = RankedTensorType::get({static_cast(gsmNBytes)}, builder.getIntegerType(8)); @@ -97,8 +96,7 @@ LogicalResult augmentTritonCallOpWithExtraArguments(ModuleOp mod, // profile scratch memory if (hasProfileScratchMemory) { - uint64_t psmNBytes = 0; - uint64_t psmAlign = 0; + uint64_t psmNBytes = 0, psmAlign = 1; if (auto psm = funcOp->getAttrOfType( "ttg.profile_scratch_memory_size")) { psmNBytes = psm.getValue().getZExtValue(); @@ -136,19 +134,12 @@ struct TritonAugmentFunctionWithExtraArgumentsPass void runOnOperation() override { auto modOp = getOperation(); - OpBuilder builder(modOp); - - bool anyFailed = false; modOp->walk([&](triton_ext::TritonCallOp op) -> WalkResult { - if (failed(augmentTritonCallOpWithExtraArguments(modOp, op, builder))) { - anyFailed = true; + if (failed(augmentTritonCallOpWithExtraArguments(modOp, op))) { + signalPassFailure(); return WalkResult::interrupt(); } return WalkResult::advance(); }); - - if (anyFailed) { - signalPassFailure(); - } } }; From 7fe06e150daf9625556cc13bc5ba31bfa88c7479 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 11 Nov 2025 01:22:13 -0600 Subject: [PATCH 3/5] feat: full lowering --- .../jax/Passes/LowerTritonExtensionOps.cpp | 131 ++++++++++++++++++ src/enzyme_ad/jax/Passes/Passes.td | 12 ++ 2 files changed, 143 insertions(+) create mode 100644 src/enzyme_ad/jax/Passes/LowerTritonExtensionOps.cpp diff --git a/src/enzyme_ad/jax/Passes/LowerTritonExtensionOps.cpp b/src/enzyme_ad/jax/Passes/LowerTritonExtensionOps.cpp new file mode 100644 index 0000000000..10a0e44a2e --- /dev/null +++ b/src/enzyme_ad/jax/Passes/LowerTritonExtensionOps.cpp @@ -0,0 +1,131 @@ +#include "src/enzyme_ad/jax/Passes/Passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "src/enzyme_ad/jax/Dialect/Dialect.h" +#include "src/enzyme_ad/jax/Dialect/Ops.h" +#include "src/enzyme_ad/jax/Dialect/TritonExt/Dialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "src/enzyme_ad/jax/Utils.h" + +#include "llvm/ADT/SmallVector.h" + +#define DEBUG_TYPE "lower-triton-extension-ops" + +namespace mlir { +namespace enzyme { +#define GEN_PASS_DEF_LOWERTRITONEXTENSIONOPSPASS +#include "src/enzyme_ad/jax/Passes/Passes.h.inc" +} // namespace enzyme +} // namespace mlir + +using namespace mlir; +using namespace mlir::enzyme; +using namespace mlir::enzymexla; +using namespace mlir::enzymexla::triton_ext; + +struct JITCallScratchMemoryLowering + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(enzymexla::JITCallOp op, + PatternRewriter &rewriter) const override { + auto inputs = op.getInputs(); + + BitVector rewriteScratchMemoryIdxs(inputs.size(), false); + SmallVector newInputs; + bool hasScratchMemory = false; + for (size_t i = 0; i < inputs.size(); i++) { + if (auto scratchMemoryOp = + inputs[i].getDefiningOp()) { + hasScratchMemory = true; + rewriteScratchMemoryIdxs.set(i); + continue; + } + newInputs.push_back(inputs[i]); + } + + if (!hasScratchMemory) + return failure(); // nothing to do + + // hoist the scratch memory allocation and use gpu.alloc to allocate this + // memory in the jit call function + auto modOp = op->getParentOfType(); + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(modOp); + auto funcOp = symbolTable.lookupNearestSymbolFrom(op, op.getFnAttr()); + if (!funcOp) { + op->emitError("Failed to find function '") << op.getFn() << "' in module"; + return failure(); + } + + auto funcOpInterface = dyn_cast(funcOp); + + auto &fnBody = funcOp->getRegion(0).front(); + rewriter.setInsertionPoint(&fnBody, fnBody.begin()); + + for (unsigned idx : rewriteScratchMemoryIdxs.set_bits()) { + auto scratchMemoryOp = + inputs[idx].getDefiningOp(); + auto outTy = + cast(scratchMemoryOp.getResult().getType()); + assert(outTy.getRank() == 1); + + auto outMemrefType = MemRefType::get( + outTy.getShape(), outTy.getElementType(), MemRefLayoutAttrInterface{}, + rewriter.getI64IntegerAttr( + cast(fnBody.getArgument(idx).getType()) + .getAddressSpace())); + + auto allocOp = + memref::AllocOp::create(rewriter, op.getLoc(), outMemrefType, + scratchMemoryOp.getAlignmentAttr()); + auto ptrOp = enzymexla::Memref2PointerOp::create( + rewriter, op.getLoc(), + LLVM::LLVMPointerType::get(rewriter.getContext(), + outMemrefType.getMemorySpaceAsInt()), + allocOp.getResult()); + rewriter.replaceAllUsesWith(fnBody.getArgument(idx), ptrOp.getResult()); + + // TODO: dealloc the ops using gpu.dealloc + } + + funcOpInterface.eraseArguments(rewriteScratchMemoryIdxs); + + // TODO: to be safe we should rework the other attributes if they are being + // removed... + rewriter.setInsertionPoint(op); + auto newJitCallOp = enzymexla::JITCallOp::create( + rewriter, op.getLoc(), op.getResultTypes(), op.getFn(), newInputs, + op.getBackendConfigAttr(), op.getOperandLayoutsAttr(), + op.getResultLayoutsAttr(), op.getArgAttrsAttr(), op.getResAttrsAttr(), + op.getOutputOperandAliasesAttr(), op.getXlaSideEffectFreeAttr()); + rewriter.replaceOp(op, newJitCallOp); + return success(); + } +}; + +struct LowerTritonExtensionOpsPass + : public mlir::enzyme::impl::LowerTritonExtensionOpsPassBase< + LowerTritonExtensionOpsPass> { + using Base::Base; + + void runOnOperation() override { + auto context = getOperation()->getContext(); + + RewritePatternSet patterns(context); + patterns.add(context); + + GreedyRewriteConfig config; + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config))) { + signalPassFailure(); + } + } +}; diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td index 4dbedb834d..a8abeecd3e 100644 --- a/src/enzyme_ad/jax/Passes/Passes.td +++ b/src/enzyme_ad/jax/Passes/Passes.td @@ -1060,4 +1060,16 @@ def TritonAugmentFunctionWithExtraArgumentsPass : Pass< ]; } +def LowerTritonExtensionOpsPass : Pass<"lower-triton-extension-ops"> { + let dependentDialects = [ + "triton::TritonDialect", + "func::FuncDialect", + "LLVM::LLVMDialect", + "memref::MemRefDialect", + "enzymexla::EnzymeXLADialect", + "enzymexla::triton_ext::TritonExtDialect", + "gpu::GPUDialect", + ]; +} + #endif From c2b5ddea0c934a9982de8266a9a2f947eaa73a93 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 11 Nov 2025 01:37:18 -0600 Subject: [PATCH 4/5] feat: deallocate the memory --- .../jax/Passes/LowerTritonExtensionOps.cpp | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/LowerTritonExtensionOps.cpp b/src/enzyme_ad/jax/Passes/LowerTritonExtensionOps.cpp index 10a0e44a2e..d5a0211e33 100644 --- a/src/enzyme_ad/jax/Passes/LowerTritonExtensionOps.cpp +++ b/src/enzyme_ad/jax/Passes/LowerTritonExtensionOps.cpp @@ -68,9 +68,9 @@ struct JITCallScratchMemoryLowering auto funcOpInterface = dyn_cast(funcOp); auto &fnBody = funcOp->getRegion(0).front(); - rewriter.setInsertionPoint(&fnBody, fnBody.begin()); for (unsigned idx : rewriteScratchMemoryIdxs.set_bits()) { + rewriter.setInsertionPoint(&fnBody, fnBody.begin()); auto scratchMemoryOp = inputs[idx].getDefiningOp(); auto outTy = @@ -93,7 +93,22 @@ struct JITCallScratchMemoryLowering allocOp.getResult()); rewriter.replaceAllUsesWith(fnBody.getArgument(idx), ptrOp.getResult()); - // TODO: dealloc the ops using gpu.dealloc + SmallVector deps; + Operation *lastUser = ptrOp; + for (auto u : ptrOp->getUsers()) { + if (auto gpuLaunchOp = dyn_cast(u)) { + deps.push_back(gpuLaunchOp.getAsyncToken()); + } + + if (lastUser->isBeforeInBlock(u)) { + lastUser = u; + } + } + + rewriter.setInsertionPointAfter(lastUser); + gpu::DeallocOp::create(rewriter, op.getLoc(), + gpu::AsyncTokenType::get(rewriter.getContext()), + ValueRange(deps), allocOp.getResult()); } funcOpInterface.eraseArguments(rewriteScratchMemoryIdxs); From fca05768db37f422b10471e6ec99281a8bfba6f3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 11 Nov 2025 09:21:12 -0600 Subject: [PATCH 5/5] fix: temporarily disable dealloc --- .../jax/Passes/LowerTritonExtensionOps.cpp | 37 +++++++++++-------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/LowerTritonExtensionOps.cpp b/src/enzyme_ad/jax/Passes/LowerTritonExtensionOps.cpp index d5a0211e33..7882e12d62 100644 --- a/src/enzyme_ad/jax/Passes/LowerTritonExtensionOps.cpp +++ b/src/enzyme_ad/jax/Passes/LowerTritonExtensionOps.cpp @@ -93,22 +93,27 @@ struct JITCallScratchMemoryLowering allocOp.getResult()); rewriter.replaceAllUsesWith(fnBody.getArgument(idx), ptrOp.getResult()); - SmallVector deps; - Operation *lastUser = ptrOp; - for (auto u : ptrOp->getUsers()) { - if (auto gpuLaunchOp = dyn_cast(u)) { - deps.push_back(gpuLaunchOp.getAsyncToken()); - } - - if (lastUser->isBeforeInBlock(u)) { - lastUser = u; - } - } - - rewriter.setInsertionPointAfter(lastUser); - gpu::DeallocOp::create(rewriter, op.getLoc(), - gpu::AsyncTokenType::get(rewriter.getContext()), - ValueRange(deps), allocOp.getResult()); + // clang-format off + // FIXME: This is producing + // error: 'llvm.call' op operand type mismatch for operand 0: '!llvm.ptr<1>' != '!llvm.ptr' + // see current operation: "llvm.call"(%61, %60) <{CConv = #llvm.cconv, TailCallKind = #llvm.tailcallkind, callee = @mgpuMemFree, fastmathFlags = #llvm.fastmath, op_bundle_sizes = array, operandSegmentSizes = array}> : (!llvm.ptr<1>, !llvm.ptr) -> () + // SmallVector deps; + // Operation *lastUser = ptrOp; + // for (auto u : ptrOp->getUsers()) { + // if (auto gpuLaunchOp = dyn_cast(u)) { + // deps.push_back(gpuLaunchOp.getAsyncToken()); + // } + + // if (lastUser->isBeforeInBlock(u)) { + // lastUser = u; + // } + // } + + // rewriter.setInsertionPointAfter(lastUser); + // gpu::DeallocOp::create(rewriter, op.getLoc(), + // gpu::AsyncTokenType::get(rewriter.getContext()), + // ValueRange(deps), allocOp.getResult()); + // clang-format on } funcOpInterface.eraseArguments(rewriteScratchMemoryIdxs);