From 277d410d0dd340c53d6e1ac87391e5971604f10e Mon Sep 17 00:00:00 2001
From: Jacob Peng
Date: Tue, 19 Aug 2025 16:03:23 -0400
Subject: [PATCH 1/3] enzymexla.pointer2memref derivative

---
 .../EnzymeXLAAutoDiffOpInterfaceImpl.cpp      | 30 +++++++++++++++++--
 .../jax/Implementations/XLADerivatives.h      |  2 ++
 2 files changed, 30 insertions(+), 2 deletions(-)

diff --git a/src/enzyme_ad/jax/Implementations/EnzymeXLAAutoDiffOpInterfaceImpl.cpp b/src/enzyme_ad/jax/Implementations/EnzymeXLAAutoDiffOpInterfaceImpl.cpp
index eb1ebfa7d5..8c6e74806f 100644
--- a/src/enzyme_ad/jax/Implementations/EnzymeXLAAutoDiffOpInterfaceImpl.cpp
+++ b/src/enzyme_ad/jax/Implementations/EnzymeXLAAutoDiffOpInterfaceImpl.cpp
@@ -1,4 +1,4 @@
-//===- CHLOAutoDiffOpInterfaceImpl.cpp - Interface external model --------===//
+//===- EnzymeXLAAutoDiffOpInterfaceImpl.cpp - Interface external model ----===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 //
 // This file contains the external model implementation of the automatic
-// differentiation op interfaces for the upstream MLIR arithmetic dialect.
+// differentiation op interfaces for the EnzymeXLA dialect.
 //
 //===----------------------------------------------------------------------===//
@@ -193,6 +193,31 @@ struct GPUWrapperOpInterfaceReverse
                                   MGradientUtilsReverse *gutils) const {}
 };
 
+class Pointer2MemrefRev : public ReverseAutoDiffOpInterface::ExternalModel<
+                              Pointer2MemrefRev, enzymexla::Pointer2MemrefOp> {
+public:
+  LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder,
+                                         MGradientUtilsReverse *gutils,
+                                         SmallVector<Value> caches) const {
+    return success();
+  }
+
+  SmallVector<Value> cacheValues(Operation *orig,
+                                 MGradientUtilsReverse *gutils) const {
+    return SmallVector<Value>();
+  }
+
+  void createShadowValues(Operation *op, OpBuilder &builder,
+                          MGradientUtilsReverse *gutils) const {
+    auto p2m = cast<enzymexla::Pointer2MemrefOp>(op);
+    if (!gutils->isConstantValue(p2m)) {
+      Value dres = gutils->invertPointerM(p2m.getSource(), builder);
+      Value shadow = builder.create<enzymexla::Pointer2MemrefOp>(
+          p2m.getLoc(), p2m.getType(), dres);
+      gutils->setDiffe(p2m, shadow, builder);
+    }
+  }
+};
 } // namespace
 
 void mlir::enzyme::registerEnzymeXLADialectAutoDiffInterface(
@@ -201,6 +226,7 @@
     registerInterfaces(context);
     GPUWrapperOp::attachInterface<GPUWrapperOpEnzymeOpsRemover>(*context);
     GPUWrapperOp::attachInterface<GPUWrapperOpInterfaceReverse>(*context);
+    enzymexla::Pointer2MemrefOp::attachInterface<Pointer2MemrefRev>(*context);
 
     // Register batching interfaces
     JITCallOp::attachInterface<SHLOGenericBatchOpInterface<JITCallOp>>(
diff --git a/src/enzyme_ad/jax/Implementations/XLADerivatives.h b/src/enzyme_ad/jax/Implementations/XLADerivatives.h
index abb7b95461..2ee49377cb 100644
--- a/src/enzyme_ad/jax/Implementations/XLADerivatives.h
+++ b/src/enzyme_ad/jax/Implementations/XLADerivatives.h
@@ -10,6 +10,7 @@
 namespace mlir {
 namespace enzyme {
+void registerEnzymeXLADialectAutoDiffInterface(mlir::DialectRegistry &registry);
 void registerMHLODialectAutoDiffInterface(mlir::DialectRegistry &registry);
 void registerStableHLODialectAutoDiffInterface(mlir::DialectRegistry &registry);
 void registerCHLODialectAutoDiffInterface(mlir::DialectRegistry &registry);
@@ -18,6 +19,7 @@ void registerTritonDialectAutoDiffInterface(mlir::DialectRegistry &registry);
 
 static inline void
 registerXLAAutoDiffInterfaces(mlir::DialectRegistry &registry) {
+  registerEnzymeXLADialectAutoDiffInterface(registry);
   registerMHLODialectAutoDiffInterface(registry);
   registerStableHLODialectAutoDiffInterface(registry);
   registerCHLODialectAutoDiffInterface(registry);

From 2e1f98ab1312f153742c784625988346cfcb20a4 Mon Sep 17 00:00:00 2001
From: Jacob Peng
Date: Mon, 20 Oct 2025 21:53:14 -0500
Subject: [PATCH 2/3] removal interface for enzymexla.gpu_wrapper

---
 .../EnzymeXLAAutoDiffOpInterfaceImpl.cpp      | 95 ++++++++++++++++++-
 .../jax/Implementations/XLADerivatives.h      |  2 -
 2 files changed, 90 insertions(+), 7 deletions(-)

diff --git a/src/enzyme_ad/jax/Implementations/EnzymeXLAAutoDiffOpInterfaceImpl.cpp b/src/enzyme_ad/jax/Implementations/EnzymeXLAAutoDiffOpInterfaceImpl.cpp
index 8c6e74806f..341b8420e2 100644
--- a/src/enzyme_ad/jax/Implementations/EnzymeXLAAutoDiffOpInterfaceImpl.cpp
+++ b/src/enzyme_ad/jax/Implementations/EnzymeXLAAutoDiffOpInterfaceImpl.cpp
@@ -15,8 +15,13 @@
 #include "Enzyme/MLIR/Interfaces/AutoDiffOpInterface.h"
 #include "Enzyme/MLIR/Interfaces/GradientUtils.h"
 #include "Enzyme/MLIR/Interfaces/GradientUtilsReverse.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/DialectRegistry.h"
 #include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/RegionUtils.h"
 #include "src/enzyme_ad/jax/Implementations/SHLOGenericBatchOpInterface.h"
 
 #include "Dialect/Ops.h"
@@ -69,12 +74,92 @@ struct GPUWrapperOpEnzymeOpsRemover
     if (gradients.empty() && pushedCaches.empty())
       return success();
 
-    if (gradients.size())
-      return failure();
+    llvm::MapVector<Value, CacheInfo> cachesMap;
+    for (auto &it : *wrapOp.getBody()) {
+      Operation *op = &it;
+      if (auto pushOp = dyn_cast<enzyme::PushOp>(op)) {
+        CacheInfo info(pushOp.getCache());
+        if (cachesMap.contains(pushOp.getValue()))
+          info = info.merge(cachesMap.lookup(pushOp.getValue()), rewriter);
+        cachesMap[pushOp.getValue()] = info;
+      }
+    }
+    SmallVector<CacheInfo> caches =
+        llvm::map_to_vector(cachesMap, [](auto p) { return std::get<1>(p); });
+
+    if (caches.empty())
+      return success();
+
+    SetVector<Value> visited;
+    getUsedValuesDefinedAbove(wrapOp.getBodyRegion(), visited);
+    SmallVector<Value> frontier = llvm::map_to_vector(
+        caches, [](CacheInfo info) { return info.pushedValue(); });
+    SetVector<Operation *> opsToMove;
+    // Traverse backward from pushed values to find operations that the pushed
+    // value depends on
+    while (!frontier.empty()) {
+      Value v = frontier.back();
+      Operation *definingOp = v.getDefiningOp();
+      frontier.pop_back();
+
+      if (!definingOp)
+        continue;
+
+      // Assume allocations and frees are legal to move
+      if (hasEffect(definingOp) ||
+          hasEffect(definingOp)) {
+        definingOp->emitError() << "cannot move op with side effects";
+        return failure();
+      }
+      opsToMove.insert(definingOp);
+
+      for (Value operand : definingOp->getOperands()) {
+        if (visited.contains(operand))
+          continue;
+
+        frontier.push_back(operand);
+        visited.insert(operand);
+      }
+    }
 
-    if (pushedCaches.size())
-      return failure();
+    // Move the push and dependent values outside of the wrapper
+    OpBuilder::InsertionGuard guard(rewriter);
+    IRMapping map;
+    rewriter.setInsertionPoint(wrapOp);
+    for (Operation *toMove : llvm::reverse(opsToMove)) {
+      Operation *cloned = rewriter.clone(*toMove, map);
+      toMove->replaceAllUsesWith(cloned->getResults());
+
+      if (auto allocOp = dyn_cast<memref::AllocOp>(cloned)) {
+        // Assume GPU allocations need to be in address space 1
+        auto gpuAlloc = gpu::AllocOp::create(
+            rewriter, allocOp.getLoc(),
+            *allocOp.getType().clonePtrWith(rewriter.getI64IntegerAttr(1),
+                                            std::nullopt),
+            /*asyncDependencies=*/ValueRange(), allocOp.getDynamicSizes(),
+            /*symbolOperands=*/ValueRange());
+        allocOp.replaceAllUsesWith(gpuAlloc.getResult(0));
+        rewriter.eraseOp(allocOp);
+      }
+    }
+
+    for (auto &info : caches) {
+      rewriter.moveOpBefore(info.pushOp, wrapOp);
+      auto revWrapper = info.popOp->getParentOfType<GPUWrapperOp>();
+      assert(revWrapper && "failed to find reverse gpu_wrapper");
+      rewriter.moveOpBefore(info.popOp, revWrapper);
+
+      for (auto user : info.popOp.getResult().getUsers()) {
+        if (isa(user)) {
+          rewriter.eraseOp(user);
+        }
+      }
+      rewriter.setInsertionPointAfter(revWrapper);
+      gpu::DeallocOp::create(rewriter, wrapOp.getLoc(), TypeRange(),
+                             info.popOp.getResult());
+    }
+
+    return success();
 
     // TODO need to convert to gpu allocations and conversion/copy
     /*
@@ -214,7 +299,7 @@ class Pointer2MemrefRev : public ReverseAutoDiffOpInterface::ExternalModel<
       Value dres = gutils->invertPointerM(p2m.getSource(), builder);
       Value shadow = builder.create<enzymexla::Pointer2MemrefOp>(
           p2m.getLoc(), p2m.getType(), dres);
-      gutils->setDiffe(p2m, shadow, builder);
+      gutils->setInvertedPointer(p2m, shadow);
     }
   }
 };
diff --git a/src/enzyme_ad/jax/Implementations/XLADerivatives.h b/src/enzyme_ad/jax/Implementations/XLADerivatives.h
index 2ee49377cb..abb7b95461 100644
--- a/src/enzyme_ad/jax/Implementations/XLADerivatives.h
+++ b/src/enzyme_ad/jax/Implementations/XLADerivatives.h
@@ -10,7 +10,6 @@
 namespace mlir {
 namespace enzyme {
-void registerEnzymeXLADialectAutoDiffInterface(mlir::DialectRegistry &registry);
 void registerMHLODialectAutoDiffInterface(mlir::DialectRegistry &registry);
 void registerStableHLODialectAutoDiffInterface(mlir::DialectRegistry &registry);
 void registerCHLODialectAutoDiffInterface(mlir::DialectRegistry &registry);
@@ -19,7 +18,6 @@ void registerTritonDialectAutoDiffInterface(mlir::DialectRegistry &registry);
 
 static inline void
 registerXLAAutoDiffInterfaces(mlir::DialectRegistry &registry) {
-  registerEnzymeXLADialectAutoDiffInterface(registry);
   registerMHLODialectAutoDiffInterface(registry);
   registerStableHLODialectAutoDiffInterface(registry);
   registerCHLODialectAutoDiffInterface(registry);

From bb21263857db32d359450eb608ce7f108cf73c56 Mon Sep 17 00:00:00 2001
From: Jacob Peng
Date: Wed, 19 Nov 2025 13:33:50 -0600
Subject: [PATCH 3/3] cudaFree fix with addr space

---
 src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp
index 7c8e8e9237..d8489be176 100644
--- a/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp
+++ b/src/enzyme_ad/jax/Passes/ConvertPolygeistToLLVM.cpp
@@ -2717,13 +2717,13 @@ class ConvertDeallocOpToGpuRuntimeCallPattern
     auto i32 = rewriter.getIntegerType(32);
     auto moduleOp = deallocOp->getParentOfType<ModuleOp>();
 
-    auto ptr1ty = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
+    auto ptrty = LLVM::LLVMPointerType::get(rewriter.getContext());
 
     if (backend == "cuda") {
       auto one = LLVM::ConstantOp::create(rewriter, loc, i64,
                                           rewriter.getI64IntegerAttr(1));
 
-      Type tys[] = {ptr1ty};
+      Type tys[] = {ptrty};
       auto cudaFreeFn =
           LLVM::lookupOrCreateFn(rewriter, moduleOp, "cudaFree", tys, i32);
       if (failed(cudaFreeFn)) {
@@ -2731,6 +2731,9 @@
         return failure();
       }
 
+      if (cast<LLVM::LLVMPointerType>(ptr.getType()).getAddressSpace() != 0)
+        ptr = LLVM::AddrSpaceCastOp::create(rewriter, loc, ptrty, ptr);
+
       Value args[] = {
           ptr,
       };
@@ -2750,8 +2753,6 @@
       };
       LLVM::CallOp::create(rewriter,
                            loc, freeFunc.value(), args);
     } else if (backend.starts_with("xla")) {
-      auto ptrty = LLVM::LLVMPointerType::get(rewriter.getContext());
-
       // handle, ptr
       Type tys[] = {ptrty, ptrty};