From f007977a3b9d9d219380cd85a197bd5548767f23 Mon Sep 17 00:00:00 2001
From: Chaitanya
Date: Mon, 18 Aug 2025 18:15:11 +0530
Subject: [PATCH 1/5] Introduce omp.target_allocmem and omp.target_freemem omp
 dialect ops. (#145464)

This PR introduces two new ops in the omp dialect, omp.target_allocmem and
omp.target_freemem.

omp.target_allocmem: Allocates heap memory on device. Will be lowered to an
omp_target_alloc call in llvm.
omp.target_freemem: Deallocates heap memory on device. Will be lowered to an
omp_target_free call in llvm.

Example:

%1 = omp.target_allocmem %device : i32, i64
omp.target_freemem %device, %1 : i32, i64

The work in this PR is cherry-picked from / inspired by @ivanradanov's commits
from the coexecute implementation:
[Add fir omp target alloc and free ops](https://github.com/ivanradanov/llvm-project/commit/be860ac8baf24b8405e6f396c75d7f0d26375de5)
[Lower omp_target_{alloc,free} to llvm](https://github.com/ivanradanov/llvm-project/commit/6e2d584dc93ff99bb89adc28c7afbc2b21c46d39)
---
 flang/include/flang/Optimizer/Support/Utils.h |  33 ++
 flang/lib/Optimizer/CodeGen/CodeGen.cpp       | 116 ++-----
 flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp |  49 +++
 flang/lib/Optimizer/Dialect/FIROps.cpp        |   1 -
 flang/lib/Optimizer/Support/Utils.cpp         |  71 +++++
 .../test/Fir/omp_target_allocmem_freemem.fir  | 294 ++++++++++++++++++
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |  94 ++++++
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 101 ++++++
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |  89 ++++++
 .../ompenmp-target-allocmem-freemem.mlir      |  42 +++
 10 files changed, 805 insertions(+), 85 deletions(-)
 create mode 100644 flang/test/Fir/omp_target_allocmem_freemem.fir
 create mode 100644 mlir/test/Target/LLVMIR/ompenmp-target-allocmem-freemem.mlir

diff --git a/flang/include/flang/Optimizer/Support/Utils.h b/flang/include/flang/Optimizer/Support/Utils.h
index 83c936b7dcada..0b31cfea0430a 100644
--- a/flang/include/flang/Optimizer/Support/Utils.h
+++ b/flang/include/flang/Optimizer/Support/Utils.h
@@ -27,6 +27,8 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/StringRef.h"
 
+#include "flang/Optimizer/CodeGen/TypeConverter.h"
+
 namespace fir {
 /// Return the integer value of a arith::ConstantOp.
 inline std::int64_t toInt(mlir::arith::ConstantOp cop) {
@@ -198,6 +200,37 @@ std::optional<llvm::ArrayRef<int64_t>> getComponentLowerBoundsIfNonDefault(
     fir::RecordType recordType, llvm::StringRef component,
     mlir::ModuleOp module, const mlir::SymbolTable *symbolTable = nullptr);
 
+/// Generate an LLVM constant value of type `ity`, using the provided offset.
+mlir::LLVM::ConstantOp
+genConstantIndex(mlir::Location loc, mlir::Type ity,
+                 mlir::ConversionPatternRewriter &rewriter,
+                 std::int64_t offset);
+
+/// Helper function for generating the LLVM IR that computes the distance
+/// in bytes between adjacent elements pointed to by a pointer
+/// of type \p ptrTy. The result is returned as a value of \p idxTy integer
+/// type.
+mlir::Value computeElementDistance(mlir::Location loc,
+                                   mlir::Type llvmObjectType, mlir::Type idxTy,
+                                   mlir::ConversionPatternRewriter &rewriter,
+                                   const mlir::DataLayout &dataLayout);
+
+// Compute the alloc scale size (constant factors encoded in the array type).
+// We do this for arrays without a constant interior or arrays of character with
+// dynamic length arrays, since those are the only ones that get decayed to a
+// pointer to the element type.
+mlir::Value genAllocationScaleSize(mlir::Location loc, mlir::Type dataTy, + mlir::Type ity, + mlir::ConversionPatternRewriter &rewriter); + +/// Perform an extension or truncation as needed on an integer value. Lowering +/// to the specific target may involve some sign-extending or truncation of +/// values, particularly to fit them from abstract box types to the +/// appropriate reified structures. +mlir::Value integerCast(const fir::LLVMTypeConverter &converter, + mlir::Location loc, + mlir::ConversionPatternRewriter &rewriter, + mlir::Type ty, mlir::Value val, bool fold = false); } // namespace fir #endif // FORTRAN_OPTIMIZER_SUPPORT_UTILS_H diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index ecc04a6c9a2be..7e6863e49ce8b 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -85,14 +85,6 @@ static inline mlir::Type getI8Type(mlir::MLIRContext *context) { return mlir::IntegerType::get(context, 8); } -static mlir::LLVM::ConstantOp -genConstantIndex(mlir::Location loc, mlir::Type ity, - mlir::ConversionPatternRewriter &rewriter, - std::int64_t offset) { - auto cattr = rewriter.getI64IntegerAttr(offset); - return rewriter.create(loc, ity, cattr); -} - static mlir::Block *createBlock(mlir::ConversionPatternRewriter &rewriter, mlir::Block *insertBefore) { assert(insertBefore && "expected valid insertion block"); @@ -203,39 +195,6 @@ getDependentTypeMemSizeFn(fir::RecordType recTy, fir::AllocaOp op, TODO(op.getLoc(), "did not find allocation function"); } -// Compute the alloc scale size (constant factors encoded in the array type). -// We do this for arrays without a constant interior or arrays of character with -// dynamic length arrays, since those are the only ones that get decayed to a -// pointer to the element type. 
-template -static mlir::Value -genAllocationScaleSize(OP op, mlir::Type ity, - mlir::ConversionPatternRewriter &rewriter) { - mlir::Location loc = op.getLoc(); - mlir::Type dataTy = op.getInType(); - auto seqTy = mlir::dyn_cast(dataTy); - fir::SequenceType::Extent constSize = 1; - if (seqTy) { - int constRows = seqTy.getConstantRows(); - const fir::SequenceType::ShapeRef &shape = seqTy.getShape(); - if (constRows != static_cast(shape.size())) { - for (auto extent : shape) { - if (constRows-- > 0) - continue; - if (extent != fir::SequenceType::getUnknownExtent()) - constSize *= extent; - } - } - } - - if (constSize != 1) { - mlir::Value constVal{ - genConstantIndex(loc, ity, rewriter, constSize).getResult()}; - return constVal; - } - return nullptr; -} - namespace { struct DeclareOpConversion : public fir::FIROpConversion { public: @@ -270,7 +229,7 @@ struct AllocaOpConversion : public fir::FIROpConversion { auto loc = alloc.getLoc(); mlir::Type ity = lowerTy().indexType(); unsigned i = 0; - mlir::Value size = genConstantIndex(loc, ity, rewriter, 1).getResult(); + mlir::Value size = fir::genConstantIndex(loc, ity, rewriter, 1).getResult(); mlir::Type firObjType = fir::unwrapRefType(alloc.getType()); mlir::Type llvmObjectType = convertObjectType(firObjType); if (alloc.hasLenParams()) { @@ -302,7 +261,8 @@ struct AllocaOpConversion : public fir::FIROpConversion { << scalarType << " with type parameters"; } } - if (auto scaleSize = genAllocationScaleSize(alloc, ity, rewriter)) + if (auto scaleSize = fir::genAllocationScaleSize( + alloc.getLoc(), alloc.getInType(), ity, rewriter)) size = rewriter.createOrFold(loc, ity, size, scaleSize); if (alloc.hasShapeOperands()) { @@ -479,7 +439,7 @@ struct BoxIsArrayOpConversion : public fir::FIROpConversion { auto loc = boxisarray.getLoc(); TypePair boxTyPair = getBoxTypePair(boxisarray.getVal().getType()); mlir::Value rank = getRankFromBox(loc, boxTyPair, a, rewriter); - mlir::Value c0 = genConstantIndex(loc, rank.getType(), rewriter, 0); + mlir::Value c0 = fir::genConstantIndex(loc, rank.getType(), rewriter, 0); rewriter.replaceOpWithNewOp( boxisarray, mlir::LLVM::ICmpPredicate::ne, rank, c0); return mlir::success(); @@ -815,7 +775,7 @@ struct ConvertOpConversion : public fir::FIROpConversion { // Do folding for constant inputs. if (auto constVal = fir::getIntIfConstant(op0)) { mlir::Value normVal = - genConstantIndex(loc, toTy, rewriter, *constVal ? 1 : 0); + fir::genConstantIndex(loc, toTy, rewriter, *constVal ? 1 : 0); rewriter.replaceOp(convert, normVal); return mlir::success(); } @@ -828,9 +788,9 @@ struct ConvertOpConversion : public fir::FIROpConversion { } // Compare the input with zero. - mlir::Value zero = genConstantIndex(loc, fromTy, rewriter, 0); - auto isTrue = rewriter.create( - loc, mlir::LLVM::ICmpPredicate::ne, op0, zero); + mlir::Value zero = fir::genConstantIndex(loc, fromTy, rewriter, 0); + auto isTrue = mlir::LLVM::ICmpOp::create( + rewriter, loc, mlir::LLVM::ICmpPredicate::ne, op0, zero); // Zero extend the i1 isTrue result to the required type (unless it is i1 // itself). @@ -1075,21 +1035,6 @@ static mlir::SymbolRefAttr getMalloc(fir::AllocMemOp op, return getMallocInModule(mod, op, rewriter, indexType); } -/// Helper function for generating the LLVM IR that computes the distance -/// in bytes between adjacent elements pointed to by a pointer -/// of type \p ptrTy. The result is returned as a value of \p idxTy integer -/// type. 
-static mlir::Value -computeElementDistance(mlir::Location loc, mlir::Type llvmObjectType, - mlir::Type idxTy, - mlir::ConversionPatternRewriter &rewriter, - const mlir::DataLayout &dataLayout) { - llvm::TypeSize size = dataLayout.getTypeSize(llvmObjectType); - unsigned short alignment = dataLayout.getTypeABIAlignment(llvmObjectType); - std::int64_t distance = llvm::alignTo(size, alignment); - return genConstantIndex(loc, idxTy, rewriter, distance); -} - /// Return value of the stride in bytes between adjacent elements /// of LLVM type \p llTy. The result is returned as a value of /// \p idxTy integer type. @@ -1098,7 +1043,7 @@ genTypeStrideInBytes(mlir::Location loc, mlir::Type idxTy, mlir::ConversionPatternRewriter &rewriter, mlir::Type llTy, const mlir::DataLayout &dataLayout) { // Create a pointer type and use computeElementDistance(). - return computeElementDistance(loc, llTy, idxTy, rewriter, dataLayout); + return fir::computeElementDistance(loc, llTy, idxTy, rewriter, dataLayout); } namespace { @@ -1117,7 +1062,8 @@ struct AllocMemOpConversion : public fir::FIROpConversion { if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy))) TODO(loc, "fir.allocmem codegen of derived type with length parameters"); mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, llvmObjectTy); - if (auto scaleSize = genAllocationScaleSize(heap, ity, rewriter)) + if (auto scaleSize = + fir::genAllocationScaleSize(loc, heap.getInType(), ity, rewriter)) size = rewriter.create(loc, ity, size, scaleSize); for (mlir::Value opnd : adaptor.getOperands()) size = rewriter.create( @@ -1140,7 +1086,8 @@ struct AllocMemOpConversion : public fir::FIROpConversion { mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy, mlir::ConversionPatternRewriter &rewriter, mlir::Type llTy) const { - return computeElementDistance(loc, llTy, idxTy, rewriter, getDataLayout()); + return fir::computeElementDistance(loc, llTy, idxTy, rewriter, + getDataLayout()); } }; } // namespace @@ -1324,7 +1271,7 @@ genCUFAllocDescriptor(mlir::Location loc, mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy); std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8; mlir::Value sizeInBytes = - genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize); + fir::genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize); llvm::SmallVector args = {sizeInBytes, sourceFile, sourceLine}; return rewriter .create(loc, fctTy, RTNAME_STRING(CUFAllocDescriptor), @@ -1580,7 +1527,7 @@ struct EmboxCommonConversion : public fir::FIROpConversion { // representation of derived types with pointer/allocatable components. // This has been seen in hashing algorithms using TRANSFER. 
mlir::Value zero = - genConstantIndex(loc, rewriter.getI64Type(), rewriter, 0); + fir::genConstantIndex(loc, rewriter.getI64Type(), rewriter, 0); descriptor = insertField(rewriter, loc, descriptor, {getLenParamFieldId(boxTy), 0}, zero); } @@ -1923,8 +1870,8 @@ struct XEmboxOpConversion : public EmboxCommonConversion { bool hasSlice = !xbox.getSlice().empty(); unsigned sliceOffset = xbox.getSliceOperandIndex(); mlir::Location loc = xbox.getLoc(); - mlir::Value zero = genConstantIndex(loc, i64Ty, rewriter, 0); - mlir::Value one = genConstantIndex(loc, i64Ty, rewriter, 1); + mlir::Value zero = fir::genConstantIndex(loc, i64Ty, rewriter, 0); + mlir::Value one = fir::genConstantIndex(loc, i64Ty, rewriter, 1); mlir::Value prevPtrOff = one; mlir::Type eleTy = boxTy.getEleTy(); const unsigned rank = xbox.getRank(); @@ -1973,7 +1920,7 @@ struct XEmboxOpConversion : public EmboxCommonConversion { prevDimByteStride = getCharacterByteSize(loc, rewriter, charTy, adaptor.getLenParams()); } else { - prevDimByteStride = genConstantIndex( + prevDimByteStride = fir::genConstantIndex( loc, i64Ty, rewriter, charTy.getLen() * lowerTy().characterBitsize(charTy) / 8); } @@ -2131,7 +2078,7 @@ struct XReboxOpConversion : public EmboxCommonConversion { if (auto charTy = mlir::dyn_cast(inputEleTy)) { if (charTy.hasConstantLen()) { mlir::Value len = - genConstantIndex(loc, idxTy, rewriter, charTy.getLen()); + fir::genConstantIndex(loc, idxTy, rewriter, charTy.getLen()); lenParams.emplace_back(len); } else { mlir::Value len = getElementSizeFromBox(loc, idxTy, inputBoxTyPair, @@ -2140,8 +2087,8 @@ struct XReboxOpConversion : public EmboxCommonConversion { assert(!isInGlobalOp(rewriter) && "character target in global op must have constant length"); mlir::Value width = - genConstantIndex(loc, idxTy, rewriter, charTy.getFKind()); - len = rewriter.create(loc, idxTy, len, width); + fir::genConstantIndex(loc, idxTy, rewriter, charTy.getFKind()); + len = mlir::LLVM::SDivOp::create(rewriter, loc, idxTy, len, width); } lenParams.emplace_back(len); } @@ -2194,8 +2141,9 @@ struct XReboxOpConversion : public EmboxCommonConversion { mlir::ConversionPatternRewriter &rewriter) const { mlir::Location loc = rebox.getLoc(); mlir::Value zero = - genConstantIndex(loc, lowerTy().indexType(), rewriter, 0); - mlir::Value one = genConstantIndex(loc, lowerTy().indexType(), rewriter, 1); + fir::genConstantIndex(loc, lowerTy().indexType(), rewriter, 0); + mlir::Value one = + fir::genConstantIndex(loc, lowerTy().indexType(), rewriter, 1); for (auto iter : llvm::enumerate(llvm::zip(extents, strides))) { mlir::Value extent = std::get<0>(iter.value()); unsigned dim = iter.index(); @@ -2227,7 +2175,7 @@ struct XReboxOpConversion : public EmboxCommonConversion { mlir::Location loc = rebox.getLoc(); mlir::Type byteTy = ::getI8Type(rebox.getContext()); mlir::Type idxTy = lowerTy().indexType(); - mlir::Value zero = genConstantIndex(loc, idxTy, rewriter, 0); + mlir::Value zero = fir::genConstantIndex(loc, idxTy, rewriter, 0); // Apply subcomponent and substring shift on base address. if (!rebox.getSubcomponent().empty() || !rebox.getSubstr().empty()) { // Cast to inputEleTy* so that a GEP can be used. @@ -2255,7 +2203,7 @@ struct XReboxOpConversion : public EmboxCommonConversion { // and strides. 
llvm::SmallVector slicedExtents; llvm::SmallVector slicedStrides; - mlir::Value one = genConstantIndex(loc, idxTy, rewriter, 1); + mlir::Value one = fir::genConstantIndex(loc, idxTy, rewriter, 1); const bool sliceHasOrigins = !rebox.getShift().empty(); unsigned sliceOps = rebox.getSliceOperandIndex(); unsigned shiftOps = rebox.getShiftOperandIndex(); @@ -2328,7 +2276,7 @@ struct XReboxOpConversion : public EmboxCommonConversion { // which may be OK if all new extents are ones, the stride does not // matter, use one. mlir::Value stride = inputStrides.empty() - ? genConstantIndex(loc, idxTy, rewriter, 1) + ? fir::genConstantIndex(loc, idxTy, rewriter, 1) : inputStrides[0]; for (unsigned i = 0; i < rebox.getShape().size(); ++i) { mlir::Value rawExtent = operands[rebox.getShapeOperandIndex() + i]; @@ -2563,9 +2511,9 @@ struct XArrayCoorOpConversion unsigned shiftOffset = coor.getShiftOperandIndex(); unsigned sliceOffset = coor.getSliceOperandIndex(); auto sliceOps = coor.getSlice().begin(); - mlir::Value one = genConstantIndex(loc, idxTy, rewriter, 1); + mlir::Value one = fir::genConstantIndex(loc, idxTy, rewriter, 1); mlir::Value prevExt = one; - mlir::Value offset = genConstantIndex(loc, idxTy, rewriter, 0); + mlir::Value offset = fir::genConstantIndex(loc, idxTy, rewriter, 0); const bool isShifted = !coor.getShift().empty(); const bool isSliced = !coor.getSlice().empty(); const bool baseIsBoxed = @@ -2895,7 +2843,7 @@ struct CoordinateOpConversion // of lower bound aspects. This both accounts for dynamically sized // types and non contiguous arrays. auto idxTy = lowerTy().indexType(); - mlir::Value off = genConstantIndex(loc, idxTy, rewriter, 0); + mlir::Value off = fir::genConstantIndex(loc, idxTy, rewriter, 0); unsigned arrayDim = arrTy.getDimension(); for (unsigned dim = 0; dim < arrayDim && it != end; ++dim, ++it) { mlir::Value stride = @@ -3808,8 +3756,8 @@ struct IsPresentOpConversion : public fir::FIROpConversion { ptr = rewriter.create(loc, ptr, 0); } mlir::LLVM::ConstantOp c0 = - genConstantIndex(isPresent.getLoc(), idxTy, rewriter, 0); - auto addr = rewriter.create(loc, idxTy, ptr); + fir::genConstantIndex(isPresent.getLoc(), idxTy, rewriter, 0); + auto addr = mlir::LLVM::PtrToIntOp::create(rewriter, loc, idxTy, ptr); rewriter.replaceOpWithNewOp( isPresent, mlir::LLVM::ICmpPredicate::ne, addr, c0); diff --git a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp index 37f1c9f97e1ce..97912bda79b08 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp @@ -21,6 +21,7 @@ #include "flang/Optimizer/Dialect/Support/FIRContext.h" #include "flang/Optimizer/Support/FatalError.h" #include "flang/Optimizer/Support/InternalNames.h" +#include "flang/Optimizer/Support/Utils.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -125,10 +126,58 @@ struct PrivateClauseOpConversion return mlir::success(); } }; + +// Convert FIR type to LLVM without turning fir.box into memory +// reference. 
+static mlir::Type convertObjectType(const fir::LLVMTypeConverter &converter, + mlir::Type firType) { + if (auto boxTy = mlir::dyn_cast(firType)) + return converter.convertBoxTypeAsStruct(boxTy); + return converter.convertType(firType); +} + +// FIR Op specific conversion for TargetAllocMemOp +struct TargetAllocMemOpConversion + : public OpenMPFIROpConversion { + using OpenMPFIROpConversion::OpenMPFIROpConversion; + + llvm::LogicalResult + matchAndRewrite(mlir::omp::TargetAllocMemOp allocmemOp, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::Type heapTy = allocmemOp.getAllocatedType(); + mlir::Location loc = allocmemOp.getLoc(); + auto ity = lowerTy().indexType(); + mlir::Type dataTy = fir::unwrapRefType(heapTy); + mlir::Type llvmObjectTy = convertObjectType(lowerTy(), dataTy); + if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy))) + TODO(loc, "omp.target_allocmem codegen of derived type with length " + "parameters"); + mlir::Value size = fir::computeElementDistance( + loc, llvmObjectTy, ity, rewriter, lowerTy().getDataLayout()); + if (auto scaleSize = fir::genAllocationScaleSize( + loc, allocmemOp.getInType(), ity, rewriter)) + size = rewriter.create(loc, ity, size, scaleSize); + for (mlir::Value opnd : adaptor.getOperands().drop_front()) + size = rewriter.create( + loc, ity, size, integerCast(lowerTy(), loc, rewriter, ity, opnd)); + auto mallocTyWidth = lowerTy().getIndexTypeBitwidth(); + auto mallocTy = + mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth); + if (mallocTyWidth != ity.getIntOrFloatBitWidth()) + size = integerCast(lowerTy(), loc, rewriter, mallocTy, size); + rewriter.modifyOpInPlace(allocmemOp, [&]() { + allocmemOp.setInType(rewriter.getI8Type()); + allocmemOp.getTypeparamsMutable().clear(); + allocmemOp.getTypeparamsMutable().append(size); + }); + return mlir::success(); + } +}; } // namespace void fir::populateOpenMPFIRToLLVMConversionPatterns( const LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns) { patterns.add(converter); patterns.add(converter); + patterns.add(converter); } diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index b6bf2753b80ce..958fc46c9e41c 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -107,7 +107,6 @@ static bool verifyTypeParamCount(mlir::Type inType, unsigned numParams) { } /// Parser shared by Alloca and Allocmem -/// /// operation ::= %res = (`fir.alloca` | `fir.allocmem`) $in_type /// ( `(` $typeparams `)` )? ( `,` $shape )? 
/// attr-dict-without-keyword diff --git a/flang/lib/Optimizer/Support/Utils.cpp b/flang/lib/Optimizer/Support/Utils.cpp index 5d663e28336c0..c71642ce4e806 100644 --- a/flang/lib/Optimizer/Support/Utils.cpp +++ b/flang/lib/Optimizer/Support/Utils.cpp @@ -50,3 +50,74 @@ std::optional> fir::getComponentLowerBoundsIfNonDefault( return componentInfo.getLowerBounds(); return std::nullopt; } + +mlir::LLVM::ConstantOp +fir::genConstantIndex(mlir::Location loc, mlir::Type ity, + mlir::ConversionPatternRewriter &rewriter, + std::int64_t offset) { + auto cattr = rewriter.getI64IntegerAttr(offset); + return rewriter.create(loc, ity, cattr); +} + +mlir::Value +fir::computeElementDistance(mlir::Location loc, mlir::Type llvmObjectType, + mlir::Type idxTy, + mlir::ConversionPatternRewriter &rewriter, + const mlir::DataLayout &dataLayout) { + llvm::TypeSize size = dataLayout.getTypeSize(llvmObjectType); + unsigned short alignment = dataLayout.getTypeABIAlignment(llvmObjectType); + std::int64_t distance = llvm::alignTo(size, alignment); + return fir::genConstantIndex(loc, idxTy, rewriter, distance); +} + +mlir::Value +fir::genAllocationScaleSize(mlir::Location loc, mlir::Type dataTy, + mlir::Type ity, + mlir::ConversionPatternRewriter &rewriter) { + auto seqTy = mlir::dyn_cast(dataTy); + fir::SequenceType::Extent constSize = 1; + if (seqTy) { + int constRows = seqTy.getConstantRows(); + const fir::SequenceType::ShapeRef &shape = seqTy.getShape(); + if (constRows != static_cast(shape.size())) { + for (auto extent : shape) { + if (constRows-- > 0) + continue; + if (extent != fir::SequenceType::getUnknownExtent()) + constSize *= extent; + } + } + } + + if (constSize != 1) { + mlir::Value constVal{ + fir::genConstantIndex(loc, ity, rewriter, constSize).getResult()}; + return constVal; + } + return nullptr; +} + +mlir::Value fir::integerCast(const fir::LLVMTypeConverter &converter, + mlir::Location loc, + mlir::ConversionPatternRewriter &rewriter, + mlir::Type ty, mlir::Value val, bool fold) { + auto valTy = val.getType(); + // If the value was not yet lowered, lower its type so that it can + // be used in getPrimitiveTypeSizeInBits. 
+ if (!mlir::isa(valTy)) + valTy = converter.convertType(valTy); + auto toSize = mlir::LLVM::getPrimitiveTypeSizeInBits(ty); + auto fromSize = mlir::LLVM::getPrimitiveTypeSizeInBits(valTy); + if (fold) { + if (toSize < fromSize) + return rewriter.createOrFold(loc, ty, val); + if (toSize > fromSize) + return rewriter.createOrFold(loc, ty, val); + } else { + if (toSize < fromSize) + return rewriter.create(loc, ty, val); + if (toSize > fromSize) + return rewriter.create(loc, ty, val); + } + return val; +} diff --git a/flang/test/Fir/omp_target_allocmem_freemem.fir b/flang/test/Fir/omp_target_allocmem_freemem.fir new file mode 100644 index 0000000000000..03eb94acb1ac7 --- /dev/null +++ b/flang/test/Fir/omp_target_allocmem_freemem.fir @@ -0,0 +1,294 @@ +// RUN: %flang_fc1 -emit-llvm %s -o - | FileCheck %s + +// UNSUPPORTED: system-windows +// Disabled on 32-bit targets due to the additional `trunc` opcodes required +// UNSUPPORTED: target-x86 +// UNSUPPORTED: target=sparc-{{.*}} +// UNSUPPORTED: target=sparcel-{{.*}} + +// CHECK-LABEL: define void @omp_target_allocmem_scalar_nonchar() { +// CHECK-NEXT: [[TMP1:%.*]] = call ptr @omp_target_alloc(i64 4, i32 0) +// CHECK-NEXT: [[TMP2:%.*]] = ptrtoint ptr [[TMP1]] to i64 +// CHECK-NEXT: [[TMP3:%.*]] = inttoptr i64 [[TMP2]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP3]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_scalar_nonchar() -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, i32 + omp.target_freemem %device, %1 : i32, i64 + return +} + +// CHECK-LABEL: define void @omp_target_allocmem_scalars_nonchar() { +// CHECK-NEXT: [[TMP1:%.*]] = call ptr @omp_target_alloc(i64 400, i32 0) +// CHECK-NEXT: [[TMP2:%.*]] = ptrtoint ptr [[TMP1]] to i64 +// CHECK-NEXT: [[TMP3:%.*]] = inttoptr i64 [[TMP2]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP3]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_scalars_nonchar() -> () { + %device = arith.constant 0 : i32 + %0 = arith.constant 100 : index + %1 = omp.target_allocmem %device : i32, i32, %0 + omp.target_freemem %device, %1 : i32, i64 + return +} + +// CHECK-LABEL: define void @omp_target_allocmem_scalar_char() { +// CHECK-NEXT: [[TMP1:%.*]] = call ptr @omp_target_alloc(i64 10, i32 0) +// CHECK-NEXT: [[TMP2:%.*]] = ptrtoint ptr [[TMP1]] to i64 +// CHECK-NEXT: [[TMP3:%.*]] = inttoptr i64 [[TMP2]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP3]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_scalar_char() -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, !fir.char<1,10> + omp.target_freemem %device, %1 : i32, i64 + return +} + +// CHECK-LABEL: define void @omp_target_allocmem_scalar_char_kind() { +// CHECK-NEXT: [[TMP1:%.*]] = call ptr @omp_target_alloc(i64 20, i32 0) +// CHECK-NEXT: [[TMP2:%.*]] = ptrtoint ptr [[TMP1]] to i64 +// CHECK-NEXT: [[TMP3:%.*]] = inttoptr i64 [[TMP2]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP3]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_scalar_char_kind() -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, !fir.char<2,10> + omp.target_freemem %device, %1 : i32, i64 + return +} + +// CHECK-LABEL: define void @omp_target_allocmem_scalar_dynchar( +// CHECK-SAME: i32 [[TMP0:%.*]]) { +// CHECK-NEXT: [[TMP2:%.*]] = sext i32 [[TMP0]] to i64 +// CHECK-NEXT: [[TMP3:%.*]] = mul i64 1, [[TMP2]] +// CHECK-NEXT: [[TMP4:%.*]] = mul i64 1, [[TMP3]] +// 
CHECK-NEXT: [[TMP5:%.*]] = call ptr @omp_target_alloc(i64 [[TMP4]], i32 0) +// CHECK-NEXT: [[TMP6:%.*]] = ptrtoint ptr [[TMP5]] to i64 +// CHECK-NEXT: [[TMP7:%.*]] = inttoptr i64 [[TMP6]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP7]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_scalar_dynchar(%l : i32) -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, !fir.char<1,?>(%l : i32) + omp.target_freemem %device, %1 : i32, i64 + return +} + + +// CHECK-LABEL: define void @omp_target_allocmem_scalar_dynchar_kind( +// CHECK-SAME: i32 [[TMP0:%.*]]) { +// CHECK-NEXT: [[TMP2:%.*]] = sext i32 [[TMP0]] to i64 +// CHECK-NEXT: [[TMP3:%.*]] = mul i64 2, [[TMP2]] +// CHECK-NEXT: [[TMP4:%.*]] = mul i64 1, [[TMP3]] +// CHECK-NEXT: [[TMP5:%.*]] = call ptr @omp_target_alloc(i64 [[TMP4]], i32 0) +// CHECK-NEXT: [[TMP6:%.*]] = ptrtoint ptr [[TMP5]] to i64 +// CHECK-NEXT: [[TMP7:%.*]] = inttoptr i64 [[TMP6]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP7]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_scalar_dynchar_kind(%l : i32) -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, !fir.char<2,?>(%l : i32) + omp.target_freemem %device, %1 : i32, i64 + return +} + + +// CHECK-LABEL: define void @omp_target_allocmem_array_of_nonchar() { +// CHECK-NEXT: [[TMP1:%.*]] = call ptr @omp_target_alloc(i64 36, i32 0) +// CHECK-NEXT: [[TMP2:%.*]] = ptrtoint ptr [[TMP1]] to i64 +// CHECK-NEXT: [[TMP3:%.*]] = inttoptr i64 [[TMP2]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP3]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_array_of_nonchar() -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, !fir.array<3x3xi32> + omp.target_freemem %device, %1 : i32, i64 + return +} + +// CHECK-LABEL: define void @omp_target_allocmem_array_of_char() { +// CHECK-NEXT: [[TMP1:%.*]] = call ptr @omp_target_alloc(i64 90, i32 0) +// CHECK-NEXT: [[TMP2:%.*]] = ptrtoint ptr [[TMP1]] to i64 +// CHECK-NEXT: [[TMP3:%.*]] = inttoptr i64 [[TMP2]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP3]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_array_of_char() -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, !fir.array<3x3x!fir.char<1,10>> + omp.target_freemem %device, %1 : i32, i64 + return +} + +// CHECK-LABEL: define void @omp_target_allocmem_array_of_dynchar( +// CHECK-SAME: i32 [[TMP0:%.*]]) { +// CHECK-NEXT: [[TMP2:%.*]] = sext i32 [[TMP0]] to i64 +// CHECK-NEXT: [[TMP3:%.*]] = mul i64 9, [[TMP2]] +// CHECK-NEXT: [[TMP4:%.*]] = mul i64 1, [[TMP3]] +// CHECK-NEXT: [[TMP5:%.*]] = call ptr @omp_target_alloc(i64 [[TMP4]], i32 0) +// CHECK-NEXT: [[TMP6:%.*]] = ptrtoint ptr [[TMP5]] to i64 +// CHECK-NEXT: [[TMP7:%.*]] = inttoptr i64 [[TMP6]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP7]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_array_of_dynchar(%l: i32) -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, !fir.array<3x3x!fir.char<1,?>>(%l : i32) + omp.target_freemem %device, %1 : i32, i64 + return +} + + +// CHECK-LABEL: define void @omp_target_allocmem_dynarray_of_nonchar( +// CHECK-SAME: i64 [[TMP0:%.*]]) { +// CHECK-NEXT: [[TMP2:%.*]] = mul i64 12, [[TMP0]] +// CHECK-NEXT: [[TMP3:%.*]] = mul i64 1, [[TMP2]] +// CHECK-NEXT: [[TMP4:%.*]] = call ptr @omp_target_alloc(i64 [[TMP3]], i32 0) +// CHECK-NEXT: [[TMP5:%.*]] = 
ptrtoint ptr [[TMP4]] to i64 +// CHECK-NEXT: [[TMP6:%.*]] = inttoptr i64 [[TMP5]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP6]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_dynarray_of_nonchar(%e: index) -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, !fir.array<3x?xi32>, %e + omp.target_freemem %device, %1 : i32, i64 + return +} + +// CHECK-LABEL: define void @omp_target_allocmem_dynarray_of_nonchar2( +// CHECK-SAME: i64 [[TMP0:%.*]]) { +// CHECK-NEXT: [[TMP2:%.*]] = mul i64 4, [[TMP0]] +// CHECK-NEXT: [[TMP3:%.*]] = mul i64 [[TMP2]], [[TMP0]] +// CHECK-NEXT: [[TMP4:%.*]] = mul i64 1, [[TMP3]] +// CHECK-NEXT: [[TMP5:%.*]] = call ptr @omp_target_alloc(i64 [[TMP4]], i32 0) +// CHECK-NEXT: [[TMP6:%.*]] = ptrtoint ptr [[TMP5]] to i64 +// CHECK-NEXT: [[TMP7:%.*]] = inttoptr i64 [[TMP6]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP7]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_dynarray_of_nonchar2(%e: index) -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, !fir.array, %e, %e + omp.target_freemem %device, %1 : i32, i64 + return +} + +// CHECK-LABEL: define void @omp_target_allocmem_dynarray_of_char( +// CHECK-SAME: i64 [[TMP0:%.*]]) { +// CHECK-NEXT: [[TMP2:%.*]] = mul i64 60, [[TMP0]] +// CHECK-NEXT: [[TMP3:%.*]] = mul i64 1, [[TMP2]] +// CHECK-NEXT: [[TMP4:%.*]] = call ptr @omp_target_alloc(i64 [[TMP3]], i32 0) +// CHECK-NEXT: [[TMP5:%.*]] = ptrtoint ptr [[TMP4]] to i64 +// CHECK-NEXT: [[TMP6:%.*]] = inttoptr i64 [[TMP5]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP6]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_dynarray_of_char(%e : index) -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, !fir.array<3x?x!fir.char<2,10>>, %e + omp.target_freemem %device, %1 : i32, i64 + return +} + + +// CHECK-LABEL: define void @omp_target_allocmem_dynarray_of_char2( +// CHECK-SAME: i64 [[TMP0:%.*]]) { +// CHECK-NEXT: [[TMP2:%.*]] = mul i64 20, [[TMP0]] +// CHECK-NEXT: [[TMP3:%.*]] = mul i64 [[TMP2]], [[TMP0]] +// CHECK-NEXT: [[TMP4:%.*]] = mul i64 1, [[TMP3]] +// CHECK-NEXT: [[TMP5:%.*]] = call ptr @omp_target_alloc(i64 [[TMP4]], i32 0) +// CHECK-NEXT: [[TMP6:%.*]] = ptrtoint ptr [[TMP5]] to i64 +// CHECK-NEXT: [[TMP7:%.*]] = inttoptr i64 [[TMP6]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP7]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_dynarray_of_char2(%e : index) -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, !fir.array>, %e, %e + omp.target_freemem %device, %1 : i32, i64 + return +} + +// CHECK-LABEL: define void @omp_target_allocmem_dynarray_of_dynchar( +// CHECK-SAME: i32 [[TMP0:%.*]], i64 [[TMP1:%.*]]) { +// CHECK-NEXT: [[TMP3:%.*]] = sext i32 [[TMP0]] to i64 +// CHECK-NEXT: [[TMP4:%.*]] = mul i64 6, [[TMP3]] +// CHECK-NEXT: [[TMP5:%.*]] = mul i64 [[TMP4]], [[TMP1]] +// CHECK-NEXT: [[TMP6:%.*]] = mul i64 1, [[TMP5]] +// CHECK-NEXT: [[TMP7:%.*]] = call ptr @omp_target_alloc(i64 [[TMP6]], i32 0) +// CHECK-NEXT: [[TMP8:%.*]] = ptrtoint ptr [[TMP7]] to i64 +// CHECK-NEXT: [[TMP9:%.*]] = inttoptr i64 [[TMP8]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP9]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_dynarray_of_dynchar(%l: i32, %e : index) -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, !fir.array<3x?x!fir.char<2,?>>(%l : i32), %e + 
omp.target_freemem %device, %1 : i32, i64 + return +} + +// CHECK-LABEL: define void @omp_target_allocmem_dynarray_of_dynchar2( +// CHECK-SAME: i32 [[TMP0:%.*]], i64 [[TMP1:%.*]]) { +// CHECK-NEXT: [[TMP3:%.*]] = sext i32 [[TMP0]] to i64 +// CHECK-NEXT: [[TMP4:%.*]] = mul i64 2, [[TMP3]] +// CHECK-NEXT: [[TMP5:%.*]] = mul i64 [[TMP4]], [[TMP1]] +// CHECK-NEXT: [[TMP6:%.*]] = mul i64 [[TMP5]], [[TMP1]] +// CHECK-NEXT: [[TMP7:%.*]] = mul i64 1, [[TMP6]] +// CHECK-NEXT: [[TMP8:%.*]] = call ptr @omp_target_alloc(i64 [[TMP7]], i32 0) +// CHECK-NEXT: [[TMP9:%.*]] = ptrtoint ptr [[TMP8]] to i64 +// CHECK-NEXT: [[TMP10:%.*]] = inttoptr i64 [[TMP9]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP10]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_dynarray_of_dynchar2(%l: i32, %e : index) -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, !fir.array>(%l : i32), %e, %e + omp.target_freemem %device, %1 : i32, i64 + return +} + +// CHECK-LABEL: define void @omp_target_allocmem_array_with_holes_nonchar( +// CHECK-SAME: i64 [[TMP0:%.*]], i64 [[TMP1:%.*]]) { +// CHECK-NEXT: [[TMP3:%.*]] = mul i64 240, [[TMP0]] +// CHECK-NEXT: [[TMP4:%.*]] = mul i64 [[TMP3]], [[TMP1]] +// CHECK-NEXT: [[TMP5:%.*]] = mul i64 1, [[TMP4]] +// CHECK-NEXT: [[TMP6:%.*]] = call ptr @omp_target_alloc(i64 [[TMP5]], i32 0) +// CHECK-NEXT: [[TMP7:%.*]] = ptrtoint ptr [[TMP6]] to i64 +// CHECK-NEXT: [[TMP8:%.*]] = inttoptr i64 [[TMP7]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP8]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_array_with_holes_nonchar(%0 : index, %1 : index) -> () { + %device = arith.constant 0 : i32 + %2 = omp.target_allocmem %device : i32, !fir.array<4x?x3x?x5xi32>, %0, %1 + omp.target_freemem %device, %2 : i32, i64 + return +} + +// CHECK-LABEL: define void @omp_target_allocmem_array_with_holes_char( +// CHECK-SAME: i64 [[TMP0:%.*]]) { +// CHECK-NEXT: [[TMP2:%.*]] = mul i64 240, [[TMP0]] +// CHECK-NEXT: [[TMP3:%.*]] = mul i64 1, [[TMP2]] +// CHECK-NEXT: [[TMP4:%.*]] = call ptr @omp_target_alloc(i64 [[TMP3]], i32 0) +// CHECK-NEXT: [[TMP5:%.*]] = ptrtoint ptr [[TMP4]] to i64 +// CHECK-NEXT: [[TMP6:%.*]] = inttoptr i64 [[TMP5]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP6]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_array_with_holes_char(%e: index) -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, !fir.array<3x?x4x!fir.char<2,10>>, %e + omp.target_freemem %device, %1 : i32, i64 + return +} + +// CHECK-LABEL: define void @omp_target_allocmem_array_with_holes_dynchar( +// CHECK-SAME: i64 [[TMP0:%.*]], i64 [[TMP1:%.*]]) { +// CHECK-NEXT: [[TMP3:%.*]] = mul i64 24, [[TMP0]] +// CHECK-NEXT: [[TMP4:%.*]] = mul i64 [[TMP3]], [[TMP1]] +// CHECK-NEXT: [[TMP5:%.*]] = mul i64 1, [[TMP4]] +// CHECK-NEXT: [[TMP6:%.*]] = call ptr @omp_target_alloc(i64 [[TMP5]], i32 0) +// CHECK-NEXT: [[TMP7:%.*]] = ptrtoint ptr [[TMP6]] to i64 +// CHECK-NEXT: [[TMP8:%.*]] = inttoptr i64 [[TMP7]] to ptr +// CHECK-NEXT: call void @omp_target_free(ptr [[TMP8]], i32 0) +// CHECK-NEXT: ret void +func.func @omp_target_allocmem_array_with_holes_dynchar(%arg0: index, %arg1: index) -> () { + %device = arith.constant 0 : i32 + %1 = omp.target_allocmem %device : i32, !fir.array<3x?x4x!fir.char<2,?>>(%arg0 : index), %arg1 + omp.target_freemem %device, %1 : i32, i64 + return +} diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 
e4f52777d8aa2..21edefe175fa7 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -2090,4 +2090,98 @@ def MaskedOp : OpenMP_Op<"masked", clauses = [
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// TargetAllocMemOp
+//===----------------------------------------------------------------------===//
+
+def TargetAllocMemOp : OpenMP_Op<"target_allocmem",
+    [MemoryEffects<[MemAlloc]>, AttrSizedOperandSegments]> {
+  let summary = "allocate storage on an OpenMP device for an object of a given type";
+
+  let description = [{
+    Allocates memory on the specified OpenMP device for an object of the given
+    type. Returns an integer value representing the device pointer to the
+    allocated memory. The memory is uninitialized after allocation. Every
+    allocation must be paired with an `omp.target_freemem` to avoid memory leaks.
+
+    * `$device`: The integer ID of the OpenMP device where the memory will be allocated.
+    * `$in_type`: The type of the object for which memory is being allocated.
+      For arrays, this can be a static or dynamic array type.
+    * `$uniq_name`: An optional unique name for the allocated memory.
+    * `$bindc_name`: An optional name used for C interoperability.
+    * `$typeparams`: Runtime type parameters for polymorphic or parameterized types.
+      These are typically integer values that define aspects of a type not fixed at compile time.
+    * `$shape`: Runtime shape operands for dynamic arrays.
+      Each operand is an integer value representing the extent of a specific dimension.
+
+    ```mlir
+    // Allocate a static 3x3 integer vector on device 0
+    %device_0 = arith.constant 0 : i32
+    %ptr_static = omp.target_allocmem %device_0 : i32, vector<3x3xi32>
+    // ... use %ptr_static ...
+    omp.target_freemem %device_0, %ptr_static : i32, i64
+
+    // Allocate a dynamic 2D Fortran array (fir.array<?x?xf32>) on device 1
+    %device_1 = arith.constant 1 : i32
+    %rows = arith.constant 10 : index
+    %cols = arith.constant 20 : index
+    %ptr_dynamic = omp.target_allocmem %device_1 : i32, !fir.array<?x?xf32>, %rows, %cols : index, index
+    // ... use %ptr_dynamic ...
+    omp.target_freemem %device_1, %ptr_dynamic : i32, i64
+    ```
+  }];
+
+  let arguments = (ins
+    Arg<AnyInteger>:$device,
+    TypeAttr:$in_type,
+    OptionalAttr<StrAttr>:$uniq_name,
+    OptionalAttr<StrAttr>:$bindc_name,
+    Variadic<AnyInteger>:$typeparams,
+    Variadic<AnyInteger>:$shape
+  );
+  let results = (outs I64);
+
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    mlir::Type getAllocatedType();
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// TargetFreeMemOp
+//===----------------------------------------------------------------------===//
+
+def TargetFreeMemOp : OpenMP_Op<"target_freemem",
+    [MemoryEffects<[MemFree]>]> {
+  let summary = "free memory on an OpenMP device";
+
+  let description = [{
+    Deallocates memory on the specified OpenMP device that was previously
+    allocated by an `omp.target_allocmem` operation. After this operation, the
+    deallocated memory is in an undefined state and should not be accessed.
+    It is crucial to ensure that all accesses to the memory region are completed
+    before `omp.target_freemem` is called to avoid undefined behavior.
+
+    * `$device`: The integer ID of the OpenMP device from which the memory will be freed.
+    * `$heapref`: The integer value representing the device pointer to the memory
+      to be deallocated, which was previously returned by `omp.target_allocmem`.
+
+    ```mlir
+    // Example of allocating and freeing memory on an OpenMP device
+    %device_id = arith.constant 0 : i32
+    %allocated_ptr = omp.target_allocmem %device_id : i32, vector<3x3xi32>
+    // ... operations using %allocated_ptr on the device ...
+    omp.target_freemem %device_id, %allocated_ptr : i32, i64
+    ```
+  }];
+
+  let arguments = (ins
+    Arg<AnyInteger>:$device,
+    Arg<I64>:$heapref
+  );
+  let assemblyFormat = "$device `,` $heapref attr-dict `:` type($device) `,` qualified(type($heapref))";
+}
+
 #endif // OPENMP_OPS
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index b8f477703933d..503cb9a5ee2b2 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -3866,6 +3866,107 @@ LogicalResult ScanOp::verify() {
                        "reduction modifier");
 }
 
+//===----------------------------------------------------------------------===//
+// TargetAllocMemOp
+//===----------------------------------------------------------------------===//
+
+mlir::Type omp::TargetAllocMemOp::getAllocatedType() {
+  return getInTypeAttr().getValue();
+}
+
+/// operation ::= %res = (`omp.target_allocmem`) $device : devicetype,
+///                      $in_type ( `(` $typeparams `)` )? ( `,` $shape )?
+///                      attr-dict-without-keyword
+static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser,
+                                               mlir::OperationState &result) {
+  auto &builder = parser.getBuilder();
+  bool hasOperands = false;
+  std::int32_t typeparamsSize = 0;
+
+  // Parse device number as a new operand
+  mlir::OpAsmParser::UnresolvedOperand deviceOperand;
+  mlir::Type deviceType;
+  if (parser.parseOperand(deviceOperand) || parser.parseColonType(deviceType))
+    return mlir::failure();
+  if (parser.resolveOperand(deviceOperand, deviceType, result.operands))
+    return mlir::failure();
+  if (parser.parseComma())
+    return mlir::failure();
+
+  mlir::Type intype;
+  if (parser.parseType(intype))
+    return mlir::failure();
+  result.addAttribute("in_type", mlir::TypeAttr::get(intype));
+  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands;
+  llvm::SmallVector<mlir::Type> typeVec;
+  if (!parser.parseOptionalLParen()) {
+    // parse the LEN params of the derived type. (<params> : <types>)
+    if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) ||
+        parser.parseColonTypeList(typeVec) || parser.parseRParen())
+      return mlir::failure();
+    typeparamsSize = operands.size();
+    hasOperands = true;
+  }
+  std::int32_t shapeSize = 0;
+  if (!parser.parseOptionalComma()) {
+    // parse size to scale by, vector of n dimensions of type index
+    if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None))
+      return mlir::failure();
+    shapeSize = operands.size() - typeparamsSize;
+    auto idxTy = builder.getIndexType();
+    for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i)
+      typeVec.push_back(idxTy);
+    hasOperands = true;
+  }
+  if (hasOperands &&
+      parser.resolveOperands(operands, typeVec, parser.getNameLoc(),
+                             result.operands))
+    return mlir::failure();
+
+  mlir::Type restype = builder.getIntegerType(64);
+  if (!restype) {
+    parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype;
+    return mlir::failure();
+  }
+  llvm::SmallVector<std::int32_t> segmentSizes{1, typeparamsSize, shapeSize};
+  result.addAttribute("operandSegmentSizes",
+                      builder.getDenseI32ArrayAttr(segmentSizes));
+  if (parser.parseOptionalAttrDict(result.attributes) ||
+      parser.addTypeToList(restype, result.types))
+    return mlir::failure();
+  return mlir::success();
+}
+
+mlir::ParseResult omp::TargetAllocMemOp::parse(mlir::OpAsmParser &parser,
+                                               mlir::OperationState &result) {
+  return parseTargetAllocMemOp(parser, result);
+}
+
+void omp::TargetAllocMemOp::print(mlir::OpAsmPrinter &p) {
+  p << " ";
+  p.printOperand(getDevice());
+  p << " : ";
+  p << getDevice().getType();
+  p << ", ";
+  p << getInType();
+  if (!getTypeparams().empty()) {
+    p << '(' << getTypeparams() << " : " << getTypeparams().getTypes() << ')';
+  }
+  for (auto sh : getShape()) {
+    p << ", ";
+    p.printOperand(sh);
+  }
+  p.printOptionalAttrDict((*this)->getAttrs(),
+                          {"in_type", "operandSegmentSizes"});
+}
+
+llvm::LogicalResult omp::TargetAllocMemOp::verify() {
+  mlir::Type outType = getType();
+  if (!mlir::dyn_cast<mlir::IntegerType>(outType))
+    return emitOpError("must be an integer type");
+  return mlir::success();
+}
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 13befc913b2d9..2faecc179216e 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -6066,6 +6066,10 @@ static bool isTargetDeviceOp(Operation *op) {
   if (mlir::isa(op))
     return true;
 
+  if (mlir::isa<omp::TargetAllocMemOp>(op) ||
+      mlir::isa<omp::TargetFreeMemOp>(op))
+    return true;
+
   if (auto parentFn = op->getParentOfType())
     if (auto declareTargetIface =
             llvm::dyn_cast(
@@ -6078,6 +6082,85 @@ static bool isTargetDeviceOp(Operation *op) {
   return false;
 }
 
+static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder,
+                                         llvm::Module *llvmModule) {
+  llvm::Type *i64Ty = builder.getInt64Ty();
+  llvm::Type *i32Ty = builder.getInt32Ty();
+  llvm::Type *returnType = builder.getPtrTy(0);
+  llvm::FunctionType *fnType =
+      llvm::FunctionType::get(returnType, {i64Ty, i32Ty}, false);
+  llvm::Function *func = cast<llvm::Function>(
+      llvmModule->getOrInsertFunction("omp_target_alloc", fnType).getCallee());
+  return func;
+}
+
+static LogicalResult
+convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
+                        LLVM::ModuleTranslation &moduleTranslation) {
+  auto allocMemOp = cast<omp::TargetAllocMemOp>(opInst);
+  if (!allocMemOp)
+    return failure();
+
+  // Get "omp_target_alloc" function
+  llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
+  llvm::Function *ompTargetAllocFunc = getOmpTargetAlloc(builder, llvmModule);
+  // Get the corresponding device value in llvm
+  mlir::Value deviceNum = allocMemOp.getDevice();
+  llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
+  // Get the allocation size.
+  llvm::DataLayout dataLayout = llvmModule->getDataLayout();
+  mlir::Type heapTy = allocMemOp.getAllocatedType();
+  llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy);
+  llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
+  llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
+  for (auto typeParam : allocMemOp.getTypeparams())
+    allocSize =
+        builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam));
+  // Create call to "omp_target_alloc" with the args as translated llvm values.
+  llvm::CallInst *call =
+      builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
+  llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty());
+
+  // Map the result
+  moduleTranslation.mapValue(allocMemOp.getResult(), resultI64);
+  return success();
+}
+
+static llvm::Function *getOmpTargetFree(llvm::IRBuilderBase &builder,
+                                        llvm::Module *llvmModule) {
+  llvm::Type *ptrTy = builder.getPtrTy(0);
+  llvm::Type *i32Ty = builder.getInt32Ty();
+  llvm::Type *voidTy = builder.getVoidTy();
+  llvm::FunctionType *fnType =
+      llvm::FunctionType::get(voidTy, {ptrTy, i32Ty}, false);
+  llvm::Function *func = dyn_cast<llvm::Function>(
+      llvmModule->getOrInsertFunction("omp_target_free", fnType).getCallee());
+  return func;
+}
+
+static LogicalResult
+convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
+                       LLVM::ModuleTranslation &moduleTranslation) {
+  auto freeMemOp = cast<omp::TargetFreeMemOp>(opInst);
+  if (!freeMemOp)
+    return failure();
+
+  // Get "omp_target_free" function
+  llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
+  llvm::Function *ompTargetFreeFunc = getOmpTargetFree(builder, llvmModule);
+  // Get the corresponding device value in llvm
+  mlir::Value deviceNum = freeMemOp.getDevice();
+  llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
+  // Get the corresponding heapref value in llvm
+  mlir::Value heapref = freeMemOp.getHeapref();
+  llvm::Value *llvmHeapref = moduleTranslation.lookupValue(heapref);
+  // Convert heapref int to ptr and call "omp_target_free"
+  llvm::Value *intToPtr =
+      builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0));
+  builder.CreateCall(ompTargetFreeFunc, {intToPtr, llvmDeviceNum});
+  return success();
+}
+
 /// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including
 /// OpenMP runtime calls).
 static LogicalResult
@@ -6252,6 +6335,12 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
         // the omp.canonical_loop.
return applyUnrollHeuristic(op, builder, moduleTranslation); }) + .Case([&](omp::TargetAllocMemOp) { + return convertTargetAllocMemOp(*op, builder, moduleTranslation); + }) + .Case([&](omp::TargetFreeMemOp) { + return convertTargetFreeMemOp(*op, builder, moduleTranslation); + }) .Default([&](Operation *inst) { return inst->emitError() << "not yet implemented: " << inst->getName(); diff --git a/mlir/test/Target/LLVMIR/ompenmp-target-allocmem-freemem.mlir b/mlir/test/Target/LLVMIR/ompenmp-target-allocmem-freemem.mlir new file mode 100644 index 0000000000000..1bc97609ccff4 --- /dev/null +++ b/mlir/test/Target/LLVMIR/ompenmp-target-allocmem-freemem.mlir @@ -0,0 +1,42 @@ +// RUN: mlir-opt %s -convert-openmp-to-llvm | mlir-translate -mlir-to-llvmir | FileCheck %s + +// This file contains MLIR test cases for omp.target_allocmem and omp.target_freemem + +// CHECK-LABEL: test_alloc_free_i64 +// CHECK: %[[ALLOC:.*]] = call ptr @omp_target_alloc(i64 8, i32 0) +// CHECK: %[[PTRTOINT:.*]] = ptrtoint ptr %[[ALLOC]] to i64 +// CHECK: %[[INTTOPTR:.*]] = inttoptr i64 %[[PTRTOINT]] to ptr +// CHECK: call void @omp_target_free(ptr %[[INTTOPTR]], i32 0) +// CHECK: ret void +llvm.func @test_alloc_free_i64() -> () { + %device = llvm.mlir.constant(0 : i32) : i32 + %1 = omp.target_allocmem %device : i32, i64 + omp.target_freemem %device, %1 : i32, i64 + llvm.return +} + +// CHECK-LABEL: test_alloc_free_vector_1d_f32 +// CHECK: %[[ALLOC:.*]] = call ptr @omp_target_alloc(i64 64, i32 0) +// CHECK: %[[PTRTOINT:.*]] = ptrtoint ptr %[[ALLOC]] to i64 +// CHECK: %[[INTTOPTR:.*]] = inttoptr i64 %[[PTRTOINT]] to ptr +// CHECK: call void @omp_target_free(ptr %[[INTTOPTR]], i32 0) +// CHECK: ret void +llvm.func @test_alloc_free_vector_1d_f32() -> () { + %device = llvm.mlir.constant(0 : i32) : i32 + %1 = omp.target_allocmem %device : i32, vector<16xf32> + omp.target_freemem %device, %1 : i32, i64 + llvm.return +} + +// CHECK-LABEL: test_alloc_free_vector_2d_f32 +// CHECK: %[[ALLOC:.*]] = call ptr @omp_target_alloc(i64 1024, i32 0) +// CHECK: %[[PTRTOINT:.*]] = ptrtoint ptr %[[ALLOC]] to i64 +// CHECK: %[[INTTOPTR:.*]] = inttoptr i64 %[[PTRTOINT]] to ptr +// CHECK: call void @omp_target_free(ptr %[[INTTOPTR]], i32 0) +// CHECK: ret void +llvm.func @test_alloc_free_vector_2d_f32() -> () { + %device = llvm.mlir.constant(0 : i32) : i32 + %1 = omp.target_allocmem %device : i32, vector<16x16xf32> + omp.target_freemem %device, %1 : i32, i64 + llvm.return +} From b74fde8ef2b900911b529ce5d8517b6a64f0fa26 Mon Sep 17 00:00:00 2001 From: Chaitanya Date: Mon, 25 Aug 2025 12:20:07 +0530 Subject: [PATCH 2/5] Add workdistribute construct in openMP dialect and in llvm frontend (#154376) This PR adds workdistribute mlir op in omp dialect and also in llvm frontend. 
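For reference, a minimal round-trip example of the new op (this mirrors the
ops.mlir test added in this patch; the op currently takes no clauses and simply
wraps a region that must be nested under omp.teams):

```mlir
func.func @omp_workdistribute() {
  omp.teams {
    // The enclosed structured block is divided into units of work
    // that are shared among the initial threads of the league.
    omp.workdistribute {
      omp.terminator
    }
    omp.terminator
  }
  return
}
```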
The work in this PR is c-p and updated from @ivanradanov commits from coexecute implementation: flang_workdistribute_iwomp_2024 --- llvm/include/llvm/Frontend/OpenMP/OMP.td | 59 ++++++++++ mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 23 ++++ mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 52 +++++++++ mlir/test/Dialect/OpenMP/invalid.mlir | 107 ++++++++++++++++++ mlir/test/Dialect/OpenMP/ops.mlir | 13 +++ 5 files changed, 254 insertions(+) diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td index 1b94657dfae1e..acb88dd447d2d 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMP.td +++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td @@ -1309,6 +1309,17 @@ def OMP_EndWorkshare : Directive<[Spelling<"end workshare">]> { let category = OMP_Workshare.category; let languages = [L_Fortran]; } +def OMP_Workdistribute : Directive<[Spelling<"workdistribute">]> { + let association = AS_Block; + let category = CA_Executable; + let languages = [L_Fortran]; +} +def OMP_EndWorkdistribute : Directive<[Spelling<"end workdistribute">]> { + let leafConstructs = OMP_Workdistribute.leafConstructs; + let association = OMP_Workdistribute.association; + let category = OMP_Workdistribute.category; + let languages = [L_Fortran]; +} //===----------------------------------------------------------------------===// // Definitions of OpenMP compound directives @@ -2452,6 +2463,35 @@ def OMP_TargetTeamsDistributeSimd let leafConstructs = [OMP_Target, OMP_Teams, OMP_Distribute, OMP_Simd]; let category = CA_Executable; } +def OMP_TargetTeamsWorkdistribute : Directive<[Spelling<"target teams workdistribute">]> { + let allowedClauses = [ + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + ]; + let allowedOnceClauses = [ + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + ]; + let leafConstructs = [OMP_Target, OMP_Teams, OMP_Workdistribute]; + let category = CA_Executable; + let languages = [L_Fortran]; +} def OMP_target_teams_loop : Directive<[Spelling<"target teams loop">]> { let allowedClauses = [ VersionedClause, @@ -2682,6 +2722,25 @@ def OMP_TeamsDistributeSimd : Directive<[Spelling<"teams distribute simd">]> { let leafConstructs = [OMP_Teams, OMP_Distribute, OMP_Simd]; let category = CA_Executable; } +def OMP_TeamsWorkdistribute : Directive<[Spelling<"teams workdistribute">]> { + let allowedClauses = [ + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + ]; + let allowedOnceClauses = [ + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + ]; + let leafConstructs = [OMP_Teams, OMP_Workdistribute]; + let category = CA_Executable; + let languages = [L_Fortran]; +} def OMP_teams_loop : Directive<[Spelling<"teams loop">]> { let allowedClauses = [ VersionedClause, diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 21edefe175fa7..5815e90a01840 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -2184,4 +2184,27 @@ def TargetFreeMemOp : OpenMP_Op<"target_freemem", let assemblyFormat = "$device `,` $heapref attr-dict `:` type($device) `,` qualified(type($heapref))"; } 
+//===----------------------------------------------------------------------===// +// workdistribute Construct +//===----------------------------------------------------------------------===// + +def WorkdistributeOp : OpenMP_Op<"workdistribute"> { + let summary = "workdistribute directive"; + let description = [{ + workdistribute divides execution of the enclosed structured block into + separate units of work, each executed only once by each + initial thread in the league. + ``` + !$omp target teams + !$omp workdistribute + y = a * x + y + !$omp end workdistribute + !$omp end target teams + ``` + }]; + let regions = (region AnyRegion:$region); + let hasVerifier = 1; + let assemblyFormat = "$region attr-dict"; +} + #endif // OPENMP_OPS diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 503cb9a5ee2b2..194a061dda162 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -3967,6 +3967,58 @@ llvm::LogicalResult omp::TargetAllocMemOp::verify() { return mlir::success(); } +//===----------------------------------------------------------------------===// +// WorkdistributeOp +//===----------------------------------------------------------------------===// + +LogicalResult WorkdistributeOp::verify() { + // Check that region exists and is not empty + Region ®ion = getRegion(); + if (region.empty()) + return emitOpError("region cannot be empty"); + // Verify single entry point. + Block &entryBlock = region.front(); + if (entryBlock.empty()) + return emitOpError("region must contain a structured block"); + // Verify single exit point. + bool hasTerminator = false; + for (Block &block : region) { + if (isa(block.back())) { + if (hasTerminator) { + return emitOpError("region must have exactly one terminator"); + } + hasTerminator = true; + } + } + if (!hasTerminator) { + return emitOpError("region must be terminated with omp.terminator"); + } + auto walkResult = region.walk([&](Operation *op) -> WalkResult { + // No implicit barrier at end + if (isa(op)) { + return emitOpError( + "explicit barriers are not allowed in workdistribute region"); + } + // Check for invalid nested constructs + if (isa(op)) { + return emitOpError( + "nested parallel constructs not allowed in workdistribute"); + } + if (isa(op)) { + return emitOpError( + "nested teams constructs not allowed in workdistribute"); + } + return WalkResult::advance(); + }); + if (walkResult.wasInterrupted()) + return failure(); + + Operation *parentOp = (*this)->getParentOp(); + if (!llvm::dyn_cast(parentOp)) + return emitOpError("workdistribute must be nested under teams"); + return success(); +} + #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 294b2cba5cf78..a34b5b41ca679 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -2987,3 +2987,110 @@ llvm.func @invalid_mapper(%0 : !llvm.ptr) { } llvm.return } + +// ----- +func.func @invalid_workdistribute_empty_region() -> () { + omp.teams { + // expected-error @below {{region cannot be empty}} + omp.workdistribute { + } + omp.terminator + } + return +} + +// ----- +func.func @invalid_workdistribute_no_terminator() -> () { + omp.teams { + // expected-error @below {{region must be terminated with omp.terminator}} + omp.workdistribute { + %c0 = arith.constant 0 : i32 + } + omp.terminator + } + return +} + +// 
----- +func.func @invalid_workdistribute_wrong_terminator() -> () { + omp.teams { + // expected-error @below {{region must be terminated with omp.terminator}} + omp.workdistribute { + %c0 = arith.constant 0 : i32 + func.return + } + omp.terminator + } + return +} + +// ----- +func.func @invalid_workdistribute_multiple_terminators() -> () { + omp.teams { + // expected-error @below {{region must have exactly one terminator}} + omp.workdistribute { + %cond = arith.constant true + cf.cond_br %cond, ^bb1, ^bb2 + ^bb1: + omp.terminator + ^bb2: + omp.terminator + } + omp.terminator + } + return +} + +// ----- +func.func @invalid_workdistribute_with_barrier() -> () { + omp.teams { + // expected-error @below {{explicit barriers are not allowed in workdistribute region}} + omp.workdistribute { + %c0 = arith.constant 0 : i32 + omp.barrier + omp.terminator + } + omp.terminator + } + return +} + +// ----- +func.func @invalid_workdistribute_nested_parallel() -> () { + omp.teams { + // expected-error @below {{nested parallel constructs not allowed in workdistribute}} + omp.workdistribute { + omp.parallel { + omp.terminator + } + omp.terminator + } + omp.terminator + } + return +} + +// ----- +// Test: nested teams not allowed in workdistribute +func.func @invalid_workdistribute_nested_teams() -> () { + omp.teams { + // expected-error @below {{nested teams constructs not allowed in workdistribute}} + omp.workdistribute { + omp.teams { + omp.terminator + } + omp.terminator + } + omp.terminator + } + return +} + +// ----- +func.func @invalid_workdistribute() -> () { +// expected-error @below {{workdistribute must be nested under teams}} + omp.workdistribute { + omp.terminator + } + return +} diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 47cfc5278a5d0..2518afe784c9a 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -3197,3 +3197,16 @@ func.func @omp_workshare_loop_wrapper_attrs(%idx : index) { } return } + +// CHECK-LABEL: func.func @omp_workdistribute +func.func @omp_workdistribute() { + // CHECK: omp.teams + omp.teams { + // CHECK: omp.workdistribute + omp.workdistribute { + omp.terminator + } + omp.terminator + } + return +} From a0f8acaaf6e35184f6efb8141d0407b018efc1c6 Mon Sep 17 00:00:00 2001 From: Chaitanya Date: Mon, 25 Aug 2025 18:37:36 +0530 Subject: [PATCH 3/5] Add parser and semantic support for workdistribute (#154377) This PR adds workdistribute parser and semantic support in flang. 
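
For reference, the construct now accepted by the parser (mirroring the new
parser test added below) looks like:

  !$omp teams workdistribute
  y = a * x + y
  !$omp end teams workdistribute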
The work in this PR is c-p and updated from @ivanradanov commits from coexecute implementation: flang_workdistribute_iwomp_2024 --- .../flang/Semantics/openmp-directive-sets.h | 7 ++ flang/lib/Parser/openmp-parsers.cpp | 7 ++ flang/lib/Parser/unparse.cpp | 9 ++ flang/lib/Semantics/check-omp-structure.cpp | 97 +++++++++++++++++++ flang/lib/Semantics/check-omp-structure.h | 1 + flang/lib/Semantics/resolve-directives.cpp | 8 +- flang/test/Parser/OpenMP/workdistribute.f90 | 27 ++++++ .../Semantics/OpenMP/workdistribute01.f90 | 16 +++ .../Semantics/OpenMP/workdistribute02.f90 | 34 +++++++ .../Semantics/OpenMP/workdistribute03.f90 | 34 +++++++ .../Semantics/OpenMP/workdistribute04.f90 | 15 +++ 11 files changed, 254 insertions(+), 1 deletion(-) create mode 100644 flang/test/Parser/OpenMP/workdistribute.f90 create mode 100644 flang/test/Semantics/OpenMP/workdistribute01.f90 create mode 100644 flang/test/Semantics/OpenMP/workdistribute02.f90 create mode 100644 flang/test/Semantics/OpenMP/workdistribute03.f90 create mode 100644 flang/test/Semantics/OpenMP/workdistribute04.f90 diff --git a/flang/include/flang/Semantics/openmp-directive-sets.h b/flang/include/flang/Semantics/openmp-directive-sets.h index dd610c9702c28..35b29dca77333 100644 --- a/flang/include/flang/Semantics/openmp-directive-sets.h +++ b/flang/include/flang/Semantics/openmp-directive-sets.h @@ -143,6 +143,7 @@ static const OmpDirectiveSet topTargetSet{ Directive::OMPD_target_teams_distribute_parallel_do_simd, Directive::OMPD_target_teams_distribute_simd, Directive::OMPD_target_teams_loop, + Directive::OMPD_target_teams_workdistribute, }; static const OmpDirectiveSet allTargetSet{topTargetSet}; @@ -172,6 +173,7 @@ static const OmpDirectiveSet topTeamsSet{ Directive::OMPD_teams_distribute_parallel_do_simd, Directive::OMPD_teams_distribute_simd, Directive::OMPD_teams_loop, + Directive::OMPD_teams_workdistribute, }; static const OmpDirectiveSet bottomTeamsSet{ @@ -187,6 +189,7 @@ static const OmpDirectiveSet allTeamsSet{ Directive::OMPD_target_teams_distribute_parallel_do_simd, Directive::OMPD_target_teams_distribute_simd, Directive::OMPD_target_teams_loop, + Directive::OMPD_target_teams_workdistribute, } | topTeamsSet, }; @@ -230,6 +233,9 @@ static const OmpDirectiveSet blockConstructSet{ Directive::OMPD_taskgroup, Directive::OMPD_teams, Directive::OMPD_workshare, + Directive::OMPD_target_teams_workdistribute, + Directive::OMPD_teams_workdistribute, + Directive::OMPD_workdistribute, }; static const OmpDirectiveSet loopConstructSet{ @@ -376,6 +382,7 @@ static const OmpDirectiveSet nestedReduceWorkshareAllowedSet{ }; static const OmpDirectiveSet nestedTeamsAllowedSet{ + Directive::OMPD_workdistribute, Directive::OMPD_distribute, Directive::OMPD_distribute_parallel_do, Directive::OMPD_distribute_parallel_do_simd, diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp index d70aaab82cbab..24b9c8790a3fa 100644 --- a/flang/lib/Parser/openmp-parsers.cpp +++ b/flang/lib/Parser/openmp-parsers.cpp @@ -1567,11 +1567,16 @@ TYPE_PARSER( "TARGET DATA" >> pure(llvm::omp::Directive::OMPD_target_data), "TARGET_DATA" >> pure(llvm::omp::Directive::OMPD_target_data), "TARGET PARALLEL" >> pure(llvm::omp::Directive::OMPD_target_parallel), + "TARGET TEAMS WORKDISTRIBUTE" >> + pure(llvm::omp::Directive::OMPD_target_teams_workdistribute), "TARGET TEAMS" >> pure(llvm::omp::Directive::OMPD_target_teams), "TARGET" >> pure(llvm::omp::Directive::OMPD_target), "TASK"_id >> pure(llvm::omp::Directive::OMPD_task), "TASKGROUP" >> 
pure(llvm::omp::Directive::OMPD_taskgroup), + "TEAMS WORKDISTRIBUTE" >> + pure(llvm::omp::Directive::OMPD_teams_workdistribute), "TEAMS" >> pure(llvm::omp::Directive::OMPD_teams), + "WORKDISTRIBUTE" >> pure(llvm::omp::Directive::OMPD_workdistribute), "WORKSHARE" >> pure(llvm::omp::Directive::OMPD_workshare)))) TYPE_PARSER(sourced(construct( @@ -1729,6 +1734,8 @@ TYPE_PARSER(sourced( TYPE_PARSER(construct( Parser{} / endOmpLine, block, Parser{} / endOmpLine)) +#define MakeBlockConstruct(dir) \ + construct(OmpBlockConstructParser{dir}) // OMP SECTIONS Directive TYPE_PARSER(construct(first( diff --git a/flang/lib/Parser/unparse.cpp b/flang/lib/Parser/unparse.cpp index fbe89c668fc13..7e5945d0c999b 100644 --- a/flang/lib/Parser/unparse.cpp +++ b/flang/lib/Parser/unparse.cpp @@ -2561,6 +2561,15 @@ class UnparseVisitor { case llvm::omp::Directive::OMPD_workshare: Word("WORKSHARE "); break; + case llvm::omp::Directive::OMPD_workdistribute: + Word("WORKDISTRIBUTE "); + break; + case llvm::omp::Directive::OMPD_teams_workdistribute: + Word("TEAMS WORKDISTRIBUTE "); + break; + case llvm::omp::Directive::OMPD_target_teams_workdistribute: + Word("TARGET TEAMS WORKDISTRIBUTE "); + break; default: // Nothing to be done break; diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp index 0bf2b7ee71f42..093fc7171b555 100644 --- a/flang/lib/Semantics/check-omp-structure.cpp +++ b/flang/lib/Semantics/check-omp-structure.cpp @@ -37,6 +37,7 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Frontend/OpenMP/OMP.h" @@ -141,6 +142,64 @@ class OmpWorkshareBlockChecker { parser::CharBlock source_; }; +// 'OmpWorkdistributeBlockChecker' is used to check the validity of the +// assignment statements and the expressions enclosed in an OpenMP +// WORKDISTRIBUTE construct +class OmpWorkdistributeBlockChecker { +public: + OmpWorkdistributeBlockChecker( + SemanticsContext &context, parser::CharBlock source) + : context_{context}, source_{source} {} + + template bool Pre(const T &) { return true; } + template void Post(const T &) {} + + bool Pre(const parser::AssignmentStmt &assignment) { + const auto &var{std::get(assignment.t)}; + const auto &expr{std::get(assignment.t)}; + const auto *lhs{GetExpr(context_, var)}; + const auto *rhs{GetExpr(context_, expr)}; + if (lhs && rhs) { + Tristate isDefined{semantics::IsDefinedAssignment( + lhs->GetType(), lhs->Rank(), rhs->GetType(), rhs->Rank())}; + if (isDefined == Tristate::Yes) { + context_.Say(expr.source, + "Defined assignment statement is not allowed in a WORKDISTRIBUTE construct"_err_en_US); + } + } + return true; + } + + bool Pre(const parser::Expr &expr) { + if (const auto *e{GetExpr(context_, expr)}) { + if (!e) + return false; + for (const Symbol &symbol : evaluate::CollectSymbols(*e)) { + const Symbol &root{GetAssociationRoot(symbol)}; + if (IsFunction(root)) { + std::vector attrs; + if (!IsElementalProcedure(root)) { + attrs.push_back("non-ELEMENTAL"); + } + if (root.attrs().test(Attr::IMPURE)) { + attrs.push_back("IMPURE"); + } + std::string attrsStr = + attrs.empty() ? 
"" : " " + llvm::join(attrs, ", "); + context_.Say(expr.source, + "User defined%s function '%s' is not allowed in a WORKDISTRIBUTE construct"_err_en_US, + attrsStr, root.name()); + } + } + } + return false; + } + +private: + SemanticsContext &context_; + parser::CharBlock source_; +}; + // `OmpUnitedTaskDesignatorChecker` is used to check if the designator // can appear within the TASK construct class OmpUnitedTaskDesignatorChecker { @@ -819,6 +878,12 @@ void OmpStructureChecker::Enter(const parser::OpenMPBlockConstruct &x) { "TARGET construct with nested TEAMS region contains statements or " "directives outside of the TEAMS construct"_err_en_US); } + if (GetContext().directive == llvm::omp::Directive::OMPD_workdistribute && + GetContextParent().directive != llvm::omp::Directive::OMPD_teams) { + context_.Say(parser::FindSourceLocation(x), + "%s region can only be strictly nested within TEAMS region"_err_en_US, + ContextDirectiveAsFortran()); + } } CheckNoBranching(block, beginDir.v, beginDir.source); @@ -900,6 +965,17 @@ void OmpStructureChecker::Enter(const parser::OpenMPBlockConstruct &x) { HasInvalidWorksharingNesting( beginDir.source, llvm::omp::nestedWorkshareErrSet); break; + case llvm::omp::OMPD_workdistribute: + if (!CurrentDirectiveIsNested()) { + context_.Say(beginDir.source, + "A WORKDISTRIBUTE region must be nested inside TEAMS region only."_err_en_US); + } + CheckWorkdistributeBlockStmts(block, beginDir.source); + break; + case llvm::omp::OMPD_teams_workdistribute: + case llvm::omp::OMPD_target_teams_workdistribute: + CheckWorkdistributeBlockStmts(block, beginDir.source); + break; case llvm::omp::Directive::OMPD_scope: case llvm::omp::Directive::OMPD_single: // TODO: This check needs to be extended while implementing nesting of @@ -4385,6 +4461,27 @@ void OmpStructureChecker::CheckWorkshareBlockStmts( } } +void OmpStructureChecker::CheckWorkdistributeBlockStmts( + const parser::Block &block, parser::CharBlock source) { + unsigned version{context_.langOptions().OpenMPVersion}; + unsigned since{60}; + if (version < since) + context_.Say(source, + "WORKDISTRIBUTE construct is not allowed in %s, %s"_err_en_US, + ThisVersion(version), TryVersion(since)); + + OmpWorkdistributeBlockChecker ompWorkdistributeBlockChecker{context_, source}; + + for (auto it{block.begin()}; it != block.end(); ++it) { + if (parser::Unwrap(*it)) { + parser::Walk(*it, ompWorkdistributeBlockChecker); + } else { + context_.Say(source, + "The structured block in a WORKDISTRIBUTE construct may consist of only SCALAR or ARRAY assignments"_err_en_US); + } + } +} + void OmpStructureChecker::CheckIfContiguous(const parser::OmpObject &object) { if (auto contig{IsContiguous(context_, object)}; contig && !*contig) { const parser::Name *name{GetObjectName(object)}; diff --git a/flang/lib/Semantics/check-omp-structure.h b/flang/lib/Semantics/check-omp-structure.h index 6a877a5d0a7c0..637c1a4b52fda 100644 --- a/flang/lib/Semantics/check-omp-structure.h +++ b/flang/lib/Semantics/check-omp-structure.h @@ -242,6 +242,7 @@ class OmpStructureChecker llvmOmpClause clause, const parser::OmpObjectList &ompObjectList); bool CheckTargetBlockOnlyTeams(const parser::Block &); void CheckWorkshareBlockStmts(const parser::Block &, parser::CharBlock); + void CheckWorkdistributeBlockStmts(const parser::Block &, parser::CharBlock); void CheckIteratorRange(const parser::OmpIteratorSpecifier &x); void CheckIteratorModifier(const parser::OmpIterator &x); diff --git a/flang/lib/Semantics/resolve-directives.cpp 
b/flang/lib/Semantics/resolve-directives.cpp index 151f4ccae634e..42d24e703889e 100644 --- a/flang/lib/Semantics/resolve-directives.cpp +++ b/flang/lib/Semantics/resolve-directives.cpp @@ -1680,10 +1680,13 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPBlockConstruct &x) { case llvm::omp::Directive::OMPD_task: case llvm::omp::Directive::OMPD_taskgroup: case llvm::omp::Directive::OMPD_teams: + case llvm::omp::Directive::OMPD_workdistribute: case llvm::omp::Directive::OMPD_workshare: case llvm::omp::Directive::OMPD_parallel_workshare: case llvm::omp::Directive::OMPD_target_teams: + case llvm::omp::Directive::OMPD_target_teams_workdistribute: case llvm::omp::Directive::OMPD_target_parallel: + case llvm::omp::Directive::OMPD_teams_workdistribute: PushContext(beginDir.source, beginDir.v); break; default: @@ -1713,9 +1716,12 @@ void OmpAttributeVisitor::Post(const parser::OpenMPBlockConstruct &x) { case llvm::omp::Directive::OMPD_target: case llvm::omp::Directive::OMPD_task: case llvm::omp::Directive::OMPD_teams: + case llvm::omp::Directive::OMPD_workdistribute: case llvm::omp::Directive::OMPD_parallel_workshare: case llvm::omp::Directive::OMPD_target_teams: - case llvm::omp::Directive::OMPD_target_parallel: { + case llvm::omp::Directive::OMPD_target_parallel: + case llvm::omp::Directive::OMPD_target_teams_workdistribute: + case llvm::omp::Directive::OMPD_teams_workdistribute: { bool hasPrivate; for (const auto *allocName : allocateNames_) { hasPrivate = false; diff --git a/flang/test/Parser/OpenMP/workdistribute.f90 b/flang/test/Parser/OpenMP/workdistribute.f90 new file mode 100644 index 0000000000000..61c91cb47cceb --- /dev/null +++ b/flang/test/Parser/OpenMP/workdistribute.f90 @@ -0,0 +1,27 @@ +!RUN: %flang_fc1 -fdebug-unparse -fopenmp -fopenmp-version=60 %s | FileCheck --ignore-case --check-prefix="UNPARSE" %s +!RUN: %flang_fc1 -fdebug-dump-parse-tree -fopenmp -fopenmp-version=60 %s | FileCheck --check-prefix="PARSE-TREE" %s + +!UNPARSE: SUBROUTINE teams_workdistribute +!UNPARSE: USE :: iso_fortran_env +!UNPARSE: REAL(KIND=4_4) a +!UNPARSE: REAL(KIND=4_4), DIMENSION(10_4) :: x +!UNPARSE: REAL(KIND=4_4), DIMENSION(10_4) :: y +!UNPARSE: !$OMP TEAMS WORKDISTRIBUTE +!UNPARSE: y=a*x+y +!UNPARSE: !$OMP END TEAMS WORKDISTRIBUTE +!UNPARSE: END SUBROUTINE teams_workdistribute + +!PARSE-TREE: | | | OmpBeginBlockDirective +!PARSE-TREE: | | | | OmpBlockDirective -> llvm::omp::Directive = teams workdistribute +!PARSE-TREE: | | | OmpEndBlockDirective +!PARSE-TREE: | | | | OmpBlockDirective -> llvm::omp::Directive = teams workdistribute + +subroutine teams_workdistribute() + use iso_fortran_env + real(kind=real32) :: a + real(kind=real32), dimension(10) :: x + real(kind=real32), dimension(10) :: y + !$omp teams workdistribute + y = a * x + y + !$omp end teams workdistribute +end subroutine teams_workdistribute diff --git a/flang/test/Semantics/OpenMP/workdistribute01.f90 b/flang/test/Semantics/OpenMP/workdistribute01.f90 new file mode 100644 index 0000000000000..f7e36976dfb65 --- /dev/null +++ b/flang/test/Semantics/OpenMP/workdistribute01.f90 @@ -0,0 +1,16 @@ +! RUN: %python %S/../test_errors.py %s %flang -fopenmp -fopenmp-version=60 +! OpenMP Version 6.0 +! workdistribute Construct +! Invalid do construct inside !$omp workdistribute + +subroutine workdistribute() + integer n, i + !ERROR: A WORKDISTRIBUTE region must be nested inside TEAMS region only. 
+ !ERROR: The structured block in a WORKDISTRIBUTE construct may consist of only SCALAR or ARRAY assignments + !$omp workdistribute + do i = 1, n + print *, "omp workdistribute" + end do + !$omp end workdistribute + +end subroutine workdistribute diff --git a/flang/test/Semantics/OpenMP/workdistribute02.f90 b/flang/test/Semantics/OpenMP/workdistribute02.f90 new file mode 100644 index 0000000000000..6de3a55f545b5 --- /dev/null +++ b/flang/test/Semantics/OpenMP/workdistribute02.f90 @@ -0,0 +1,34 @@ +! RUN: %python %S/../test_errors.py %s %flang -fopenmp -fopenmp-version=60 +! OpenMP Version 6.0 +! workdistribute Construct +! The !omp workdistribute construct must not contain any user defined +! function calls unless the function is ELEMENTAL. + +module my_mod + contains + integer function my_func() + my_func = 10 + end function my_func + + impure integer function impure_my_func() + impure_my_func = 20 + end function impure_my_func + + impure elemental integer function impure_ele_my_func() + impure_ele_my_func = 20 + end function impure_ele_my_func +end module my_mod + +subroutine workdistribute(aa, bb, cc, n) + use my_mod + integer n + real aa(n), bb(n), cc(n) + !$omp teams + !$omp workdistribute + !ERROR: User defined non-ELEMENTAL function 'my_func' is not allowed in a WORKDISTRIBUTE construct + aa = my_func() + aa = bb * cc + !$omp end workdistribute + !$omp end teams + +end subroutine workdistribute diff --git a/flang/test/Semantics/OpenMP/workdistribute03.f90 b/flang/test/Semantics/OpenMP/workdistribute03.f90 new file mode 100644 index 0000000000000..828170a016ed2 --- /dev/null +++ b/flang/test/Semantics/OpenMP/workdistribute03.f90 @@ -0,0 +1,34 @@ +! RUN: %python %S/../test_errors.py %s %flang -fopenmp -fopenmp-version=60 +! OpenMP Version 6.0 +! workdistribute Construct +! All array assignments, scalar assignments, and masked array assignments +! must be intrinsic assignments. + +module defined_assign + interface assignment(=) + module procedure work_assign + end interface + + contains + subroutine work_assign(a,b) + integer, intent(out) :: a + logical, intent(in) :: b(:) + end subroutine work_assign +end module defined_assign + +program omp_workdistribute + use defined_assign + + integer :: a, aa(10), bb(10) + logical :: l(10) + l = .TRUE. + + !$omp teams + !$omp workdistribute + !ERROR: Defined assignment statement is not allowed in a WORKDISTRIBUTE construct + a = l + aa = bb + !$omp end workdistribute + !$omp end teams + +end program omp_workdistribute diff --git a/flang/test/Semantics/OpenMP/workdistribute04.f90 b/flang/test/Semantics/OpenMP/workdistribute04.f90 new file mode 100644 index 0000000000000..d407e8a073ae4 --- /dev/null +++ b/flang/test/Semantics/OpenMP/workdistribute04.f90 @@ -0,0 +1,15 @@ +! RUN: %python %S/../test_errors.py %s %flang -fopenmp -fopenmp-version=50 +! OpenMP Version 6.0 +! workdistribute Construct +! 
Unsupported OpenMP version
+
+subroutine teams_workdistribute()
+  use iso_fortran_env
+  real(kind=real32) :: a
+  real(kind=real32), dimension(10) :: x
+  real(kind=real32), dimension(10) :: y
+  !ERROR: WORKDISTRIBUTE construct is not allowed in OpenMP v5.0, try -fopenmp-version=60
+  !$omp teams workdistribute
+  y = a * x + y
+  !$omp end teams workdistribute
+end subroutine teams_workdistribute

From 31535a39c5ac6121ac6c4c8912920e4d7ba42ea2 Mon Sep 17 00:00:00 2001
From: Chaitanya
Date: Tue, 26 Aug 2025 09:30:21 +0530
Subject: [PATCH 4/5] Add Lowering to omp mlir for workdistribute construct
 (#154378)

This PR adds lowering of the workdistribute construct in flang to the omp mlir
dialect workdistribute op.

The work in this PR is c-p and updated from @ivanradanov commits from
coexecute implementation: flang_workdistribute_iwomp_2024
---
 flang/lib/Lower/OpenMP/OpenMP.cpp          | 23 ++++++++++++++++-
 flang/test/Lower/OpenMP/workdistribute.f90 | 30 ++++++++++++++++++++++
 2 files changed, 52 insertions(+), 1 deletion(-)
 create mode 100644 flang/test/Lower/OpenMP/workdistribute.f90

diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 985920a72a919..ad2c7b92d677e 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -618,6 +618,13 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
       cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
       break;

+    case OMPD_teams_workdistribute:
+      cp.processThreadLimit(stmtCtx, hostInfo->ops);
+      [[fallthrough]];
+    case OMPD_target_teams_workdistribute:
+      cp.processNumTeams(stmtCtx, hostInfo->ops);
+      break;
+
     // Standalone 'target' case.
     case OMPD_target: {
       processSingleNestedIf(
@@ -2631,6 +2638,17 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
                       queue, item, clauseOps);
 }

+static mlir::omp::WorkdistributeOp genWorkdistributeOp(
+    lower::AbstractConverter &converter, lower::SymMap &symTable,
+    semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+    mlir::Location loc, const ConstructQueue &queue,
+    ConstructQueue::const_iterator item) {
+  return genOpWithBody<mlir::omp::WorkdistributeOp>(
+      OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
+                        llvm::omp::Directive::OMPD_workdistribute),
+      queue, item);
+}
+
 //===----------------------------------------------------------------------===//
 // Code generation functions for the standalone version of constructs that can
 // also be a leaf of a composite construct
@@ -3262,7 +3280,10 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
       TODO(loc, "Unhandled loop directive (" +
                     llvm::omp::getOpenMPDirectiveName(dir, version) + ")");
     }
-  // case llvm::omp::Directive::OMPD_workdistribute:
+  case llvm::omp::Directive::OMPD_workdistribute:
+    newOp = genWorkdistributeOp(converter, symTable, semaCtx, eval, loc, queue,
+                                item);
+    break;
   case llvm::omp::Directive::OMPD_workshare:
     newOp = genWorkshareOp(converter, symTable, stmtCtx, semaCtx, eval, loc,
                            queue, item);
diff --git a/flang/test/Lower/OpenMP/workdistribute.f90 b/flang/test/Lower/OpenMP/workdistribute.f90
new file mode 100644
index 0000000000000..7a938b59b8094
--- /dev/null
+++ b/flang/test/Lower/OpenMP/workdistribute.f90
@@ -0,0 +1,30 @@
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s
+
+! CHECK-LABEL: func @_QPtarget_teams_workdistribute
+subroutine target_teams_workdistribute()
+  integer :: aa(10), bb(10)
+  ! CHECK: omp.target
+  ! CHECK: omp.teams
+  ! CHECK: omp.workdistribute
+  !$omp target teams workdistribute
+  aa = bb
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end target teams workdistribute
+end subroutine target_teams_workdistribute
+
+! CHECK-LABEL: func @_QPteams_workdistribute
+subroutine teams_workdistribute()
+  use iso_fortran_env
+  real(kind=real32) :: a
+  real(kind=real32), dimension(10) :: x
+  real(kind=real32), dimension(10) :: y
+  ! CHECK: omp.teams
+  ! CHECK: omp.workdistribute
+  !$omp teams workdistribute
+  y = a * x + y
+  ! CHECK: omp.terminator
+  ! CHECK: omp.terminator
+  !$omp end teams workdistribute
+end subroutine teams_workdistribute

From e029fc7468abcff4ac5a40a9fff978275342255f Mon Sep 17 00:00:00 2001
From: Chaitanya
Date: Sat, 18 Oct 2025 07:56:32 +0530
Subject: [PATCH 5/5] Implement workdistribute construct lowering (#140523)

This PR introduces a new pass "lower-workdistribute".

Fortran array statements are lowered to fir as fir.do_loop unordered. The
"lower-workdistribute" pass works mainly on identifying "fir.do_loop unordered"
that is nested in target{teams{workdistribute{fir.do_loop unordered}}} and
lowers it to target{teams{parallel{distribute{wsloop{loop_nest}}}}}. It hoists
all the other ops outside the target region. It replaces heap allocation on the
target with omp.target_allocmem and deallocation with omp.target_freemem from
the host. It also replaces the runtime function "Assign" with omp_target_memcpy
from the host.

This pass implements the following rewrites and optimisations:
- **FissionWorkdistribute**: finds the parallelizable ops within the
  teams {workdistribute} region and moves them to their own
  teams {workdistribute} region.
- **WorkdistributeRuntimeCallLower**: finds the FortranAAssign calls nested in
  teams {workdistribute{}} and lowers them to an unordered do loop if the src
  is scalar and the dest is an array. Other runtime calls are not handled
  currently.
- **WorkdistributeDoLower**: finds the fir.do_loop unordered nested in
  teams {workdistribute{fir.do_loop unordered}} and lowers it to
  teams {parallel {distribute {wsloop {loop_nest}}}}.
- **TeamsWorkdistributeToSingle**: hoists all the ops inside
  teams {workdistribute{}} before the teams op.
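
As a simplified before/after sketch of the core rewrite (types, clauses and
the composite markings elided), a parallelizable array statement inside

  omp.target {
    omp.teams {
      omp.workdistribute {
        fir.do_loop %i = %lb to %ub step %st unordered { ... }
      }
    }
  }

is rewritten to

  omp.target {
    omp.teams {
      omp.parallel {
        omp.distribute {
          omp.wsloop {
            omp.loop_nest (%i) : index = (%lb) to (%ub) inclusive step (%st) { ... }
          }
        }
      }
    }
  }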
The work in this PR is C-P and updated from @ivanradanov commits from
coexecute implementation:
[flang_workdistribute_iwomp_2024](https://github.com/ivanradanov/llvm-project/commits/flang_workdistribute_iwomp_2024)

Paper related to this work by @ivanradanov:
["Automatic Parallelization and OpenMP Offloading of Fortran Array Notation"](https://www.osti.gov/servlets/purl/2449728)
---
 .../include/flang/Optimizer/OpenMP/Passes.td  |    4 +
 flang/lib/Optimizer/OpenMP/CMakeLists.txt     |    1 +
 .../Optimizer/OpenMP/LowerWorkdistribute.cpp  | 1852 +++++++++++++++++
 flang/lib/Optimizer/Passes/Pipelines.cpp      |    4 +-
 flang/test/Fir/basic-program.fir              |    1 +
 .../Lower/OpenMP/workdistribute-multiple.f90  |   20 +
 .../Lower/OpenMP/workdistribute-saxpy-1d.f90  |   39 +
 .../Lower/OpenMP/workdistribute-saxpy-2d.f90  |   45 +
 .../Lower/OpenMP/workdistribute-saxpy-3d.f90  |   47 +
 ...workdistribute-saxpy-and-scalar-assign.f90 |   53 +
 .../OpenMP/workdistribute-saxpy-two-2d.f90    |   68 +
 .../OpenMP/workdistribute-scalar-assign.f90   |   29 +
 .../workdistribute-target-teams-clauses.f90   |   32 +
 ...workdistribute-teams-unsupported-after.f90 |   22 +
 ...orkdistribute-teams-unsupported-before.f90 |   22 +
 .../OpenMP/lower-workdistribute-doloop.mlir   |   33 +
 .../lower-workdistribute-fission-host.mlir    |  117 ++
 .../lower-workdistribute-fission-target.mlir  |  118 ++
 .../OpenMP/lower-workdistribute-fission.mlir  |   71 +
 ...-workdistribute-runtime-assign-scalar.mlir |  108 +
 20 files changed, 2685 insertions(+), 1 deletion(-)
 create mode 100644 flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
 create mode 100644 flang/test/Lower/OpenMP/workdistribute-multiple.f90
 create mode 100644 flang/test/Lower/OpenMP/workdistribute-saxpy-1d.f90
 create mode 100644 flang/test/Lower/OpenMP/workdistribute-saxpy-2d.f90
 create mode 100644 flang/test/Lower/OpenMP/workdistribute-saxpy-3d.f90
 create mode 100644 flang/test/Lower/OpenMP/workdistribute-saxpy-and-scalar-assign.f90
 create mode 100644 flang/test/Lower/OpenMP/workdistribute-saxpy-two-2d.f90
 create mode 100644 flang/test/Lower/OpenMP/workdistribute-scalar-assign.f90
 create mode 100644 flang/test/Lower/OpenMP/workdistribute-target-teams-clauses.f90
 create mode 100644 flang/test/Lower/OpenMP/workdistribute-teams-unsupported-after.f90
 create mode 100644 flang/test/Lower/OpenMP/workdistribute-teams-unsupported-before.f90
 create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir
 create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-fission-host.mlir
 create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir
 create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir
 create mode 100644 flang/test/Transforms/OpenMP/lower-workdistribute-runtime-assign-scalar.mlir

diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index 2b6540cb18d26..b709a9dbfed23 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -109,6 +109,10 @@ def LowerWorkshare : Pass<"lower-workshare", "::mlir::ModuleOp"> {
   let summary = "Lower workshare construct";
 }

+def LowerWorkdistribute : Pass<"lower-workdistribute", "::mlir::ModuleOp"> {
+  let summary = "Lower workdistribute construct";
+}
+
 def GenericLoopConversionPass
     : Pass<"omp-generic-loop-conversion", "mlir::func::FuncOp"> {
   let summary = "Converts OpenMP generic `omp.loop` to semantically "
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index 0b8bae01dadec..b76e2d2bb997c 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -8,6 +8,7 @@ add_flang_library(FlangOpenMPTransforms
   MapsForPrivatizedSymbols.cpp
   MapInfoFinalization.cpp
   MarkDeclareTarget.cpp
+  LowerWorkdistribute.cpp
   LowerWorkshare.cpp
   LowerNontemporal.cpp
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
new file mode 100644
index 0000000000000..cfa39e142907c
--- /dev/null
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -0,0 +1,1852 @@
+//===- LowerWorkdistribute.cpp --------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the lowering and optimisations of omp.workdistribute.
+//
+// Fortran array statements are lowered to fir as fir.do_loop unordered.
+// The lower-workdistribute pass works mainly on identifying fir.do_loop
+// unordered that is nested in target{teams{workdistribute{fir.do_loop
+// unordered}}} and lowers it to
+// target{teams{parallel{distribute{wsloop{loop_nest}}}}}.
+// It hoists all the other ops outside the target region.
+// It replaces heap allocation on the target with omp.target_allocmem and
+// deallocation with omp.target_freemem from the host. It also replaces the
+// runtime function "Assign" with omp_target_memcpy.
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/HLFIR/Passes.h"
+#include "flang/Optimizer/OpenMP/Utils.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/Frontend/OpenMP/OMPConstants.h"
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace flangomp {
+#define GEN_PASS_DEF_LOWERWORKDISTRIBUTE
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+
+#define DEBUG_TYPE "lower-workdistribute"
+
+using namespace mlir;
+
+namespace {
+
+/// This string is used to identify the Fortran-specific runtime FortranAAssign.
+static constexpr llvm::StringRef FortranAssignStr = "_FortranAAssign";
+
+/// The isRuntimeCall function is a utility designed to determine
+/// if a given operation is a call to a Fortran-specific runtime function.
+static bool isRuntimeCall(Operation *op) {
+  if (auto callOp = dyn_cast<fir::CallOp>(op)) {
+    auto callee = callOp.getCallee();
+    if (!callee)
+      return false;
+    auto *func = op->getParentOfType<ModuleOp>().lookupSymbol(*callee);
+    if (func->getAttr(fir::FIROpsDialect::getFirRuntimeAttrName()))
+      return true;
+  }
+  return false;
+}
+
+/// This is the single source of truth about whether we should parallelize an
+/// operation nested in an omp.workdistribute region.
+/// Parallelize here refers to dividing into units of work. +static bool shouldParallelize(Operation *op) { + // True if the op is a runtime call to Assign + if (isRuntimeCall(op)) { + fir::CallOp runtimeCall = cast(op); + auto funcName = runtimeCall.getCallee()->getRootReference().getValue(); + if (funcName == FortranAssignStr) { + return true; + } + } + // We cannot parallelize ops with side effects. + // Parallelizable operations should not produce + // values that other operations depend on + if (llvm::any_of(op->getResults(), + [](OpResult v) -> bool { return !v.use_empty(); })) + return false; + // We will parallelize unordered loops - these come from array syntax + if (auto loop = dyn_cast(op)) { + auto unordered = loop.getUnordered(); + if (!unordered) + return false; + return *unordered; + } + // We cannot parallelize anything else. + return false; +} + +/// The getPerfectlyNested function is a generic utility for finding +/// a single, "perfectly nested" operation within a parent operation. +template +static T getPerfectlyNested(Operation *op) { + if (op->getNumRegions() != 1) + return nullptr; + auto ®ion = op->getRegion(0); + if (region.getBlocks().size() != 1) + return nullptr; + auto *block = ®ion.front(); + auto *firstOp = &block->front(); + if (auto nested = dyn_cast(firstOp)) + if (firstOp->getNextNode() == block->getTerminator()) + return nested; + return nullptr; +} + +/// verifyTargetTeamsWorkdistribute method verifies that +/// omp.target { teams { workdistribute { ... } } } is well formed +/// and fails for function calls that don't have lowering implemented yet. +static LogicalResult +verifyTargetTeamsWorkdistribute(omp::WorkdistributeOp workdistribute) { + OpBuilder rewriter(workdistribute); + auto loc = workdistribute->getLoc(); + auto teams = dyn_cast(workdistribute->getParentOp()); + if (!teams) { + emitError(loc, "workdistribute not nested in teams\n"); + return failure(); + } + if (workdistribute.getRegion().getBlocks().size() != 1) { + emitError(loc, "workdistribute with multiple blocks\n"); + return failure(); + } + if (teams.getRegion().getBlocks().size() != 1) { + emitError(loc, "teams with multiple blocks\n"); + return failure(); + } + + bool foundWorkdistribute = false; + for (auto &op : teams.getOps()) { + if (isa(op)) { + if (foundWorkdistribute) { + emitError(loc, "teams has multiple workdistribute ops.\n"); + return failure(); + } + foundWorkdistribute = true; + continue; + } + // Identify any omp dialect ops present before/after workdistribute. + if (op.getDialect() && isa(op.getDialect()) && + !isa(op)) { + emitError(loc, "teams has omp ops other than workdistribute. Lowering " + "not implemented yet.\n"); + return failure(); + } + } + + omp::TargetOp targetOp = dyn_cast(teams->getParentOp()); + // return if not omp.target + if (!targetOp) + return success(); + + for (auto &op : workdistribute.getOps()) { + if (auto callOp = dyn_cast(op)) { + if (isRuntimeCall(&op)) { + auto funcName = (*callOp.getCallee()).getRootReference().getValue(); + // _FortranAAssign is handled. Other runtime calls are not supported + // in omp.workdistribute yet. + if (funcName == FortranAssignStr) + continue; + else { + emitError(loc, "Runtime call " + funcName + + " lowering not supported for workdistribute yet."); + return failure(); + } + } + } + } + return success(); +} + +/// fissionWorkdistribute method finds the parallelizable ops +/// within teams {workdistribute} region and moves them to their +/// own teams{workdistribute} region. 
+/// +/// If B() and D() are parallelizable, +/// +/// omp.teams { +/// omp.workdistribute { +/// A() +/// B() +/// C() +/// D() +/// E() +/// } +/// } +/// +/// becomes +/// +/// A() +/// omp.teams { +/// omp.workdistribute { +/// B() +/// } +/// } +/// C() +/// omp.teams { +/// omp.workdistribute { +/// D() +/// } +/// } +/// E() +static FailureOr +fissionWorkdistribute(omp::WorkdistributeOp workdistribute) { + OpBuilder rewriter(workdistribute); + auto loc = workdistribute->getLoc(); + auto teams = dyn_cast(workdistribute->getParentOp()); + auto *teamsBlock = &teams.getRegion().front(); + bool changed = false; + // Move the ops inside teams and before workdistribute outside. + IRMapping irMapping; + llvm::SmallVector teamsHoisted; + for (auto &op : teams.getOps()) { + if (&op == workdistribute) { + break; + } + if (shouldParallelize(&op)) { + emitError(loc, "teams has parallelize ops before first workdistribute\n"); + return failure(); + } else { + rewriter.setInsertionPoint(teams); + rewriter.clone(op, irMapping); + teamsHoisted.push_back(&op); + changed = true; + } + } + for (auto *op : llvm::reverse(teamsHoisted)) { + op->replaceAllUsesWith(irMapping.lookup(op)); + op->erase(); + } + + // While we have unhandled operations in the original workdistribute + auto *workdistributeBlock = &workdistribute.getRegion().front(); + auto *terminator = workdistributeBlock->getTerminator(); + while (&workdistributeBlock->front() != terminator) { + rewriter.setInsertionPoint(teams); + IRMapping mapping; + llvm::SmallVector hoisted; + Operation *parallelize = nullptr; + for (auto &op : workdistribute.getOps()) { + if (&op == terminator) { + break; + } + if (shouldParallelize(&op)) { + parallelize = &op; + break; + } else { + rewriter.clone(op, mapping); + hoisted.push_back(&op); + changed = true; + } + } + + for (auto *op : llvm::reverse(hoisted)) { + op->replaceAllUsesWith(mapping.lookup(op)); + op->erase(); + } + + if (parallelize && hoisted.empty() && + parallelize->getNextNode() == terminator) + break; + if (parallelize) { + auto newTeams = rewriter.cloneWithoutRegions(teams); + auto *newTeamsBlock = rewriter.createBlock( + &newTeams.getRegion(), newTeams.getRegion().begin(), {}, {}); + for (auto arg : teamsBlock->getArguments()) + newTeamsBlock->addArgument(arg.getType(), arg.getLoc()); + auto newWorkdistribute = rewriter.create(loc); + rewriter.create(loc); + rewriter.createBlock(&newWorkdistribute.getRegion(), + newWorkdistribute.getRegion().begin(), {}, {}); + auto *cloned = rewriter.clone(*parallelize); + parallelize->replaceAllUsesWith(cloned); + parallelize->erase(); + rewriter.create(loc); + changed = true; + } + } + return changed; +} + +/// Generate omp.parallel operation with an empty region. +static void genParallelOp(Location loc, OpBuilder &rewriter, bool composite) { + auto parallelOp = rewriter.create(loc); + parallelOp.setComposite(composite); + rewriter.createBlock(¶llelOp.getRegion()); + rewriter.setInsertionPoint(rewriter.create(loc)); + return; +} + +/// Generate omp.distribute operation with an empty region. +static void genDistributeOp(Location loc, OpBuilder &rewriter, bool composite) { + mlir::omp::DistributeOperands distributeClauseOps; + auto distributeOp = + rewriter.create(loc, distributeClauseOps); + distributeOp.setComposite(composite); + auto distributeBlock = rewriter.createBlock(&distributeOp.getRegion()); + rewriter.setInsertionPointToStart(distributeBlock); + return; +} + +/// Generate loop nest clause operands from fir.do_loop operation. 
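+/// As an illustrative sketch (types elided), for
+///   fir.do_loop %i = %lb to %ub step %st unordered
+/// the bounds are forwarded so that the loop nest later built from these
+/// operands becomes
+///   omp.loop_nest (%i) : index = (%lb) to (%ub) inclusive step (%st)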
+static void
+genLoopNestClauseOps(OpBuilder &rewriter, fir::DoLoopOp loop,
+                     mlir::omp::LoopNestOperands &loopNestClauseOps) {
+  assert(loopNestClauseOps.loopLowerBounds.empty() &&
+         "Loop nest bounds were already emitted!");
+  loopNestClauseOps.loopLowerBounds.push_back(loop.getLowerBound());
+  loopNestClauseOps.loopUpperBounds.push_back(loop.getUpperBound());
+  loopNestClauseOps.loopSteps.push_back(loop.getStep());
+  loopNestClauseOps.loopInclusive = rewriter.getUnitAttr();
+}
+
+/// Generate omp.wsloop operation with an empty region and
+/// clone the body of the fir.do_loop operation inside the loop nest region.
+static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop,
+                        const mlir::omp::LoopNestOperands &clauseOps,
+                        bool composite) {
+
+  auto wsloopOp = rewriter.create<mlir::omp::WsloopOp>(doLoop.getLoc());
+  wsloopOp.setComposite(composite);
+  rewriter.createBlock(&wsloopOp.getRegion());
+
+  auto loopNestOp =
+      rewriter.create<mlir::omp::LoopNestOp>(doLoop.getLoc(), clauseOps);
+
+  // Clone the loop's body inside the loop nest construct using the
+  // mapped values.
+  rewriter.cloneRegionBefore(doLoop.getRegion(), loopNestOp.getRegion(),
+                             loopNestOp.getRegion().begin());
+  Block *clonedBlock = &loopNestOp.getRegion().back();
+  mlir::Operation *terminatorOp = clonedBlock->getTerminator();
+
+  // Erase the fir.result op of the do loop and create a yield op.
+  if (auto resultOp = dyn_cast<fir::ResultOp>(terminatorOp)) {
+    rewriter.setInsertionPoint(terminatorOp);
+    rewriter.create<mlir::omp::YieldOp>(doLoop->getLoc());
+    terminatorOp->erase();
+  }
+}
+
+/// workdistributeDoLower finds the fir.do_loop unordered
+/// nested in teams {workdistribute{fir.do_loop unordered}} and
+/// lowers it to teams {parallel {distribute {wsloop {loop_nest}}}}.
+///
+/// If a fir.do_loop is present inside teams workdistribute
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     fir.do_loop unordered {
+///       ...
+///     }
+///   }
+/// }
+///
+/// then it is lowered to
+///
+/// omp.teams {
+///   omp.parallel {
+///     omp.distribute {
+///       omp.wsloop {
+///         omp.loop_nest
+///         ...
+///       }
+///     }
+///   }
+/// }
+static bool
+workdistributeDoLower(omp::WorkdistributeOp workdistribute,
+                      SetVector<omp::TargetOp> &targetOpsToProcess) {
+  OpBuilder rewriter(workdistribute);
+  auto doLoop = getPerfectlyNested<fir::DoLoopOp>(workdistribute);
+  auto wdLoc = workdistribute->getLoc();
+  if (doLoop && shouldParallelize(doLoop)) {
+    assert(doLoop.getReduceOperands().empty());
+
+    // Record the target ops to process later
+    if (auto teamsOp = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp())) {
+      auto targetOp = dyn_cast<omp::TargetOp>(teamsOp->getParentOp());
+      if (targetOp) {
+        targetOpsToProcess.insert(targetOp);
+      }
+    }
+    // Generate the nested parallel, distribute, wsloop and loop_nest ops.
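+    // Passing composite=true marks the generated parallel, distribute and
+    // wsloop ops as parts of a single composite construct wrapping one
+    // loop nest.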
+ genParallelOp(wdLoc, rewriter, true); + genDistributeOp(wdLoc, rewriter, true); + mlir::omp::LoopNestOperands loopNestClauseOps; + genLoopNestClauseOps(rewriter, doLoop, loopNestClauseOps); + genWsLoopOp(rewriter, doLoop, loopNestClauseOps, true); + workdistribute.erase(); + return true; + } + return false; +} + +/// Check if the enclosed type in fir.ref is fir.box and fir.box encloses array +static bool isEnclosedTypeRefToBoxArray(Type type) { + // Check if it's a reference type + if (auto refType = dyn_cast(type)) { + // Get the referenced type (should be fir.box) + auto referencedType = refType.getEleTy(); + // Check if referenced type is a box + if (auto boxType = dyn_cast(referencedType)) { + // Get the boxed type and check if it's an array + auto boxedType = boxType.getEleTy(); + // Check if boxed type is a sequence (array) + return isa(boxedType); + } + } + return false; +} + +/// Check if the enclosed type in fir.box is scalar (not array) +static bool isEnclosedTypeBoxScalar(Type type) { + // Check if it's a box type + if (auto boxType = dyn_cast(type)) { + // Get the boxed type + auto boxedType = boxType.getEleTy(); + // Check if boxed type is NOT a sequence (array) + return !isa(boxedType); + } + return false; +} + +/// Check if the FortranAAssign call has src as scalar and dest as array +static bool isFortranAssignSrcScalarAndDestArray(fir::CallOp callOp) { + if (callOp.getNumOperands() < 2) + return false; + auto srcArg = callOp.getOperand(1); + auto destArg = callOp.getOperand(0); + // Both operands should be fir.convert ops + auto srcConvert = srcArg.getDefiningOp(); + auto destConvert = destArg.getDefiningOp(); + if (!srcConvert || !destConvert) { + emitError(callOp->getLoc(), + "Unimplemented: FortranAssign to OpenMP lowering\n"); + return false; + } + // Get the original types before conversion + auto srcOrigType = srcConvert.getValue().getType(); + auto destOrigType = destConvert.getValue().getType(); + + // Check if src is scalar and dest is array + bool srcIsScalar = isEnclosedTypeBoxScalar(srcOrigType); + bool destIsArray = isEnclosedTypeRefToBoxArray(destOrigType); + return srcIsScalar && destIsArray; +} + +/// Convert a flat index to multi-dimensional indices for an array box +/// Example: 2D array with shape (2,4) +/// Col 1 Col 2 Col 3 Col 4 +/// Row 1: (1,1) (1,2) (1,3) (1,4) +/// Row 2: (2,1) (2,2) (2,3) (2,4) +/// +/// extents: (2,4) +/// +/// flatIdx: 0 1 2 3 4 5 6 7 +/// Indices: (1,1) (1,2) (1,3) (1,4) (2,1) (2,2) (2,3) (2,4) +static SmallVector convertFlatToMultiDim(OpBuilder &builder, + Location loc, Value flatIdx, + Value arrayBox) { + // Get array type and rank + auto boxType = cast(arrayBox.getType()); + auto seqType = cast(boxType.getEleTy()); + int rank = seqType.getDimension(); + + // Get all extents + SmallVector extents; + // Get extents for each dimension + for (int i = 0; i < rank; ++i) { + auto dimIdx = builder.create(loc, i); + auto boxDims = fir::BoxDimsOp::create(builder, loc, arrayBox, dimIdx); + extents.push_back(boxDims.getResult(1)); + } + + // Convert flat index to multi-dimensional indices + SmallVector indices(rank); + Value temp = flatIdx; + auto c1 = builder.create(loc, 1); + + // Work backwards through dimensions (row-major order) + for (int i = rank - 1; i >= 0; --i) { + Value zeroBasedIdx = builder.create(loc, temp, extents[i]); + // Convert to one-based index + indices[i] = builder.create(loc, zeroBasedIdx, c1); + if (i > 0) { + temp = builder.create(loc, temp, extents[i]); + } + } + + return indices; +} + +/// Calculate the 
total number of elements in the array box +/// (totalElems = extent(1) * extent(2) * ... * extent(n)) +static Value CalculateTotalElements(OpBuilder &builder, Location loc, + Value arrayBox) { + auto boxType = cast(arrayBox.getType()); + auto seqType = cast(boxType.getEleTy()); + int rank = seqType.getDimension(); + + Value totalElems = nullptr; + for (int i = 0; i < rank; ++i) { + auto dimIdx = builder.create(loc, i); + auto boxDims = fir::BoxDimsOp::create(builder, loc, arrayBox, dimIdx); + Value extent = boxDims.getResult(1); + if (i == 0) { + totalElems = extent; + } else { + totalElems = builder.create(loc, totalElems, extent); + } + } + return totalElems; +} + +/// Replace the FortranAAssign runtime call with an unordered do loop +static void replaceWithUnorderedDoLoop(OpBuilder &builder, Location loc, + omp::TeamsOp teamsOp, + omp::WorkdistributeOp workdistribute, + fir::CallOp callOp) { + auto destConvert = callOp.getOperand(0).getDefiningOp(); + auto srcConvert = callOp.getOperand(1).getDefiningOp(); + + Value destBox = destConvert.getValue(); + Value srcBox = srcConvert.getValue(); + + // get defining alloca op of destBox and srcBox + auto destAlloca = destBox.getDefiningOp(); + + if (!destAlloca) { + emitError(loc, "Unimplemented: FortranAssign to OpenMP lowering\n"); + return; + } + + // get the store op that stores to the alloca + for (auto user : destAlloca->getUsers()) { + if (auto storeOp = dyn_cast(user)) { + destBox = storeOp.getValue(); + break; + } + } + + builder.setInsertionPoint(teamsOp); + // Load destination array box (if it's a reference) + Value arrayBox = destBox; + if (isa(destBox.getType())) + arrayBox = builder.create(loc, destBox); + + auto scalarValue = builder.create(loc, srcBox); + Value scalar = builder.create(loc, scalarValue); + + // Calculate total number of elements (flattened) + auto c0 = builder.create(loc, 0); + auto c1 = builder.create(loc, 1); + Value totalElems = CalculateTotalElements(builder, loc, arrayBox); + + auto *workdistributeBlock = &workdistribute.getRegion().front(); + builder.setInsertionPointToStart(workdistributeBlock); + // Create single unordered loop for flattened array + auto doLoop = fir::DoLoopOp::create(builder, loc, c0, totalElems, c1, true); + Block *loopBlock = &doLoop.getRegion().front(); + builder.setInsertionPointToStart(doLoop.getBody()); + + auto flatIdx = loopBlock->getArgument(0); + SmallVector indices = + convertFlatToMultiDim(builder, loc, flatIdx, arrayBox); + // Use fir.array_coor for linear addressing + auto elemPtr = fir::ArrayCoorOp::create( + builder, loc, fir::ReferenceType::get(scalar.getType()), arrayBox, + nullptr, nullptr, ValueRange{indices}, ValueRange{}); + + builder.create(loc, scalar, elemPtr); +} + +/// workdistributeRuntimeCallLower method finds the runtime calls +/// nested in teams {workdistribute{}} and +/// lowers FortranAAssign to unordered do loop if src is scalar and dest is +/// array. Other runtime calls are not handled currently. 
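+/// An illustrative sketch (simplified, with types elided): a call such as
+///   fir.call @_FortranAAssign(%dest_box, %src_box, ...)
+/// with a scalar src and an array dest is rewritten to, roughly,
+///   %v = fir.load (the scalar referenced by %src_box)
+///   fir.do_loop %i = %c0 to %total step %c1 unordered {
+///     %addr = fir.array_coor %dest_array, (indices derived from %i)
+///     fir.store %v to %addr
+///   }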
+static FailureOr +workdistributeRuntimeCallLower(omp::WorkdistributeOp workdistribute, + SetVector &targetOpsToProcess) { + OpBuilder rewriter(workdistribute); + auto loc = workdistribute->getLoc(); + auto teams = dyn_cast(workdistribute->getParentOp()); + if (!teams) { + emitError(loc, "workdistribute not nested in teams\n"); + return failure(); + } + if (workdistribute.getRegion().getBlocks().size() != 1) { + emitError(loc, "workdistribute with multiple blocks\n"); + return failure(); + } + if (teams.getRegion().getBlocks().size() != 1) { + emitError(loc, "teams with multiple blocks\n"); + return failure(); + } + bool changed = false; + // Get the target op parent of teams + omp::TargetOp targetOp = dyn_cast(teams->getParentOp()); + SmallVector opsToErase; + for (auto &op : workdistribute.getOps()) { + if (isRuntimeCall(&op)) { + rewriter.setInsertionPoint(&op); + fir::CallOp runtimeCall = cast(op); + auto funcName = runtimeCall.getCallee()->getRootReference().getValue(); + if (funcName == FortranAssignStr) { + if (isFortranAssignSrcScalarAndDestArray(runtimeCall) && targetOp) { + // Record the target ops to process later + targetOpsToProcess.insert(targetOp); + replaceWithUnorderedDoLoop(rewriter, loc, teams, workdistribute, + runtimeCall); + opsToErase.push_back(&op); + changed = true; + } + } + } + } + // Erase the runtime calls that have been replaced. + for (auto *op : opsToErase) { + op->erase(); + } + return changed; +} + +/// teamsWorkdistributeToSingleOp method hoists all the ops inside +/// teams {workdistribute{}} before teams op. +/// +/// If A() and B () are present inside teams workdistribute +/// +/// omp.teams { +/// omp.workdistribute { +/// A() +/// B() +/// } +/// } +/// +/// Then, its lowered to +/// +/// A() +/// B() +/// +/// If only the terminator remains in teams after hoisting, we erase teams op. +static bool +teamsWorkdistributeToSingleOp(omp::TeamsOp teamsOp, + SetVector &targetOpsToProcess) { + auto workdistributeOp = getPerfectlyNested(teamsOp); + if (!workdistributeOp) + return false; + // Get the block containing teamsOp (the parent block). + Block *parentBlock = teamsOp->getBlock(); + Block &workdistributeBlock = *workdistributeOp.getRegion().begin(); + // Record the target ops to process later + for (auto &op : workdistributeBlock.getOperations()) { + if (shouldParallelize(&op)) { + auto targetOp = dyn_cast(teamsOp->getParentOp()); + if (targetOp) { + targetOpsToProcess.insert(targetOp); + } + } + } + auto insertPoint = Block::iterator(teamsOp); + // Get the range of operations to move (excluding the terminator). + auto workdistributeBegin = workdistributeBlock.begin(); + auto workdistributeEnd = workdistributeBlock.getTerminator()->getIterator(); + // Move the operations from workdistribute block to before teamsOp. + parentBlock->getOperations().splice(insertPoint, + workdistributeBlock.getOperations(), + workdistributeBegin, workdistributeEnd); + // Erase the now-empty workdistributeOp. + workdistributeOp.erase(); + Block &teamsBlock = *teamsOp.getRegion().begin(); + // Check if only the terminator remains and erase teams op. 
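+  // (The non-terminator ops were spliced out above, so a single remaining
+  // op can only be the terminator.)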
+  if (teamsBlock.getOperations().size() == 1 &&
+      teamsBlock.getTerminator() != nullptr) {
+    teamsOp.erase();
+  }
+  return true;
+}
+
+/// If multiple workdistribute ops are nested in a target region, we will need
+/// to split the target region, but we want to preserve the data semantics of
+/// the original data region and avoid unnecessary data movement at each of the
+/// subkernels - we split the target region into a target_data{target}
+/// nest where only the outer one moves the data.
+static FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp,
+                                                RewriterBase &rewriter) {
+  auto loc = targetOp->getLoc();
+  if (targetOp.getMapVars().empty()) {
+    emitError(loc, "Target region has no data maps\n");
+    return failure();
+  }
+  // Collect all the mapinfo ops
+  SmallVector<omp::MapInfoOp> mapInfos;
+  for (auto opr : targetOp.getMapVars()) {
+    auto mapInfo = cast<omp::MapInfoOp>(opr.getDefiningOp());
+    mapInfos.push_back(mapInfo);
+  }
+
+  rewriter.setInsertionPoint(targetOp);
+  SmallVector<Value> innerMapInfos;
+  SmallVector<Value> outerMapInfos;
+  // Create new mapinfo ops for the inner target region
+  for (auto mapInfo : mapInfos) {
+    auto originalMapType =
+        (llvm::omp::OpenMPOffloadMappingFlags)(mapInfo.getMapType());
+    auto originalCaptureType = mapInfo.getMapCaptureType();
+    llvm::omp::OpenMPOffloadMappingFlags newMapType;
+    mlir::omp::VariableCaptureKind newCaptureType;
+    // For bycopy, we keep the same map type and capture type
+    // For byref, we change the map type to none and keep the capture type
+    if (originalCaptureType == mlir::omp::VariableCaptureKind::ByCopy) {
+      newMapType = originalMapType;
+      newCaptureType = originalCaptureType;
+    } else if (originalCaptureType == mlir::omp::VariableCaptureKind::ByRef) {
+      newMapType = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
+      newCaptureType = originalCaptureType;
+      outerMapInfos.push_back(mapInfo);
+    } else {
+      emitError(targetOp->getLoc(), "Unhandled case");
+      return failure();
+    }
+    auto innerMapInfo = cast<omp::MapInfoOp>(rewriter.clone(*mapInfo));
+    innerMapInfo.setMapTypeAttr(rewriter.getIntegerAttr(
+        rewriter.getIntegerType(64, false),
+        static_cast<
+            std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+            newMapType)));
+    innerMapInfo.setMapCaptureType(newCaptureType);
+    innerMapInfos.push_back(innerMapInfo.getResult());
+  }
+
+  rewriter.setInsertionPoint(targetOp);
+  auto device = targetOp.getDevice();
+  auto ifExpr = targetOp.getIfExpr();
+  auto deviceAddrVars = targetOp.getHasDeviceAddrVars();
+  auto devicePtrVars = targetOp.getIsDevicePtrVars();
+  // Create the target data op
+  auto targetDataOp = rewriter.create<omp::TargetDataOp>(
+      loc, device, ifExpr, outerMapInfos, deviceAddrVars, devicePtrVars);
+  auto targetDataBlock = rewriter.createBlock(&targetDataOp.getRegion());
+  rewriter.create<omp::TerminatorOp>(loc);
+  rewriter.setInsertionPointToStart(targetDataBlock);
+  // Create the inner target op
+  auto newTargetOp = rewriter.create<omp::TargetOp>(
+      targetOp.getLoc(), targetOp.getAllocateVars(),
+      targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+      targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+      targetOp.getHostEvalVars(), targetOp.getIfExpr(),
+      targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
+      targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
+      innerMapInfos, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
+      targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
+      targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+  rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(),
+                              newTargetOp.getRegion().begin());
+  rewriter.replaceOp(targetOp, targetDataOp);
+  return newTargetOp;
+}
+
+/// getNestedOpToIsolate identifies the teams op within the body of an
+/// omp::TargetOp that should be "isolated". It returns a tuple holding the op
+/// and two flags that tell whether the op is the first op and/or the last op
+/// in the target block.
+static std::optional<std::tuple<Operation *, bool, bool>>
+getNestedOpToIsolate(omp::TargetOp targetOp) {
+  if (targetOp.getRegion().empty())
+    return std::nullopt;
+  auto *targetBlock = &targetOp.getRegion().front();
+  for (auto &op : *targetBlock) {
+    bool first = &op == &*targetBlock->begin();
+    bool last = op.getNextNode() == targetBlock->getTerminator();
+    if (first && last)
+      return std::nullopt;
+
+    if (isa<omp::TeamsOp>(&op))
+      return {{&op, first, last}};
+  }
+  return std::nullopt;
+}
+
+/// Temporary structure to hold the two mapinfo ops
+struct TempOmpVar {
+  omp::MapInfoOp from, to;
+};
+
+/// isPtr checks if the type is a pointer or reference type.
+static bool isPtr(Type ty) {
+  return isa<fir::ReferenceType>(ty) || isa<LLVM::LLVMPointerType>(ty);
+}
+
+/// getPtrTypeForOmp returns an LLVM pointer type for the given type.
+static Type getPtrTypeForOmp(Type ty) {
+  if (isPtr(ty))
+    return LLVM::LLVMPointerType::get(ty.getContext());
+  else
+    return fir::ReferenceType::get(ty);
+}
+
+/// allocateTempOmpVar allocates a temporary variable for OpenMP mapping
+static TempOmpVar allocateTempOmpVar(Location loc, Type ty,
+                                     RewriterBase &rewriter) {
+  MLIRContext &ctx = *ty.getContext();
+  Value alloc;
+  Type allocType;
+  auto llvmPtrTy = LLVM::LLVMPointerType::get(&ctx);
+  // Get the appropriate type for allocation
+  if (isPtr(ty)) {
+    Type intTy = rewriter.getI32Type();
+    auto one = rewriter.create<LLVM::ConstantOp>(loc, intTy, 1);
+    allocType = llvmPtrTy;
+    alloc = rewriter.create<LLVM::AllocaOp>(loc, llvmPtrTy, allocType, one);
+    allocType = intTy;
+  } else {
+    allocType = ty;
+    alloc = rewriter.create<fir::AllocaOp>(loc, allocType);
+  }
+  // Lambda to create mapinfo ops
+  auto getMapInfo = [&](uint64_t mappingFlags, const char *name) {
+    return rewriter.create<omp::MapInfoOp>(
+        loc, alloc.getType(), alloc, TypeAttr::get(allocType),
+        rewriter.getIntegerAttr(rewriter.getIntegerType(64, /*isSigned=*/false),
+                                mappingFlags),
+        rewriter.getAttr<omp::VariableCaptureKindAttr>(
+            omp::VariableCaptureKind::ByRef),
+        /*varPtrPtr=*/Value{},
+        /*members=*/SmallVector<Value>{},
+        /*member_index=*/mlir::ArrayAttr{},
+        /*bounds=*/ValueRange(),
+        /*mapperId=*/mlir::FlatSymbolRefAttr(),
+        /*name=*/rewriter.getStringAttr(name), rewriter.getBoolAttr(false));
+  };
+  // Create mapinfo ops.
+  uint64_t mapFrom =
+      static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+          llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
+  uint64_t mapTo =
+      static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+          llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
+  auto mapInfoFrom = getMapInfo(mapFrom, "__flang_workdistribute_from");
+  auto mapInfoTo = getMapInfo(mapTo, "__flang_workdistribute_to");
+  return TempOmpVar{mapInfoFrom, mapInfoTo};
+}
+
+// usedOutsideSplit checks if a value is used outside the split operation.
+static bool usedOutsideSplit(Value v, Operation *split) {
+  if (!split)
+    return false;
+  auto targetOp = cast<omp::TargetOp>(split->getParentOp());
+  auto *targetBlock = &targetOp.getRegion().front();
+  for (auto *user : v.getUsers()) {
+    while (user->getBlock() != targetBlock) {
+      user = user->getParentOp();
+    }
+    if (!user->isBeforeInBlock(split))
+      return true;
+  }
+  return false;
+}
+
+/// isRecomputableAfterFission checks if an operation can be recomputed
+static bool isRecomputableAfterFission(Operation *op, Operation *splitBefore) {
+  // If the op has side effects, it cannot be recomputed.
+  // We consider fir.declare as having no side effects.
+  return isa<fir::DeclareOp>(op) || isMemoryEffectFree(op);
+}
+
+/// collectNonRecomputableDeps collects dependencies that cannot be recomputed.
+static void collectNonRecomputableDeps(Value &v, omp::TargetOp targetOp,
+                                       SetVector<Operation *> &nonRecomputable,
+                                       SetVector<Operation *> &toCache,
+                                       SetVector<Operation *> &toRecompute) {
+  Operation *op = v.getDefiningOp();
+  // If v is a block argument, it must be from the targetOp.
+  if (!op) {
+    assert(cast<BlockArgument>(v).getOwner()->getParentOp() == targetOp);
+    return;
+  }
+  // If the op is in the nonRecomputable set, add it to toCache and return.
+  if (nonRecomputable.contains(op)) {
+    toCache.insert(op);
+    return;
+  }
+  // Add the op to toRecompute.
+  toRecompute.insert(op);
+  for (auto opr : op->getOperands())
+    collectNonRecomputableDeps(opr, targetOp, nonRecomputable, toCache,
+                               toRecompute);
+}
+
+/// createBlockArgsAndMap creates block arguments on the new target block and
+/// maps the original values to them.
+static void createBlockArgsAndMap(Location loc, RewriterBase &rewriter,
+                                  omp::TargetOp &targetOp, Block *targetBlock,
+                                  Block *newTargetBlock,
+                                  SmallVector<Value> &hostEvalVars,
+                                  SmallVector<Value> &mapOperands,
+                                  SmallVector<Value> &allocs,
+                                  IRMapping &irMapping) {
+  // FIRST: Map `host_eval_vars` to block arguments.
+  unsigned originalHostEvalVarsSize = targetOp.getHostEvalVars().size();
+  for (unsigned i = 0; i < hostEvalVars.size(); ++i) {
+    Value originalValue;
+    BlockArgument newArg;
+    if (i < originalHostEvalVarsSize) {
+      originalValue = targetBlock->getArgument(i); // host_eval args come first
+      newArg = newTargetBlock->addArgument(originalValue.getType(),
+                                           originalValue.getLoc());
+    } else {
+      originalValue = hostEvalVars[i];
+      newArg = newTargetBlock->addArgument(originalValue.getType(),
+                                           originalValue.getLoc());
+    }
+    irMapping.map(originalValue, newArg);
+  }
+
+  // SECOND: Map `map_operands` to block arguments.
+  unsigned originalMapVarsSize = targetOp.getMapVars().size();
+  for (unsigned i = 0; i < mapOperands.size(); ++i) {
+    Value originalValue;
+    BlockArgument newArg;
+    // Map the new arguments from the original block.
+    if (i < originalMapVarsSize) {
+      originalValue = targetBlock->getArgument(originalHostEvalVarsSize +
+                                               i); // offset by host_eval count
+      newArg = newTargetBlock->addArgument(originalValue.getType(),
+                                           originalValue.getLoc());
+    }
+    // Map the new arguments from the `allocs`.
+    else {
+      originalValue = allocs[i - originalMapVarsSize];
+      newArg = newTargetBlock->addArgument(
+          getPtrTypeForOmp(originalValue.getType()), originalValue.getLoc());
+    }
+    irMapping.map(originalValue, newArg);
+  }
+
+  // THIRD: Map `private_vars` to block arguments (if any).
+  unsigned originalPrivateVarsSize = targetOp.getPrivateVars().size();
+  for (unsigned i = 0; i < originalPrivateVarsSize; ++i) {
+    auto originalArg = targetBlock->getArgument(originalHostEvalVarsSize +
+                                                originalMapVarsSize + i);
+    auto newArg = newTargetBlock->addArgument(originalArg.getType(),
+                                              originalArg.getLoc());
+    irMapping.map(originalArg, newArg);
+  }
+  return;
+}
+
+/// reloadCacheAndRecompute reloads cached values and recomputes operations
+/// that were marked as recomputable.
+static void reloadCacheAndRecompute(
+    Location loc, RewriterBase &rewriter, Operation *splitBefore,
+    omp::TargetOp &targetOp, Block *targetBlock, Block *newTargetBlock,
+    SmallVector<Value> &hostEvalVars, SmallVector<Value> &mapOperands,
+    SmallVector<Value> &allocs, SetVector<Operation *> &toRecompute,
+    IRMapping &irMapping) {
+  // Handle the load operations for the allocs.
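+  // Restore pattern (illustrative): a pointer-like value comes back as
+  //   %raw = llvm.load %blockArg : !llvm.ptr -> !llvm.ptr
+  // followed by a fir.convert to the original FIR type when needed;
+  // everything else is a plain fir.load.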
+  rewriter.setInsertionPointToStart(newTargetBlock);
+  auto llvmPtrTy = LLVM::LLVMPointerType::get(targetOp.getContext());
+
+  unsigned originalMapVarsSize = targetOp.getMapVars().size();
+  unsigned hostEvalVarsSize = hostEvalVars.size();
+  // Create load operations for each allocated variable.
+  for (unsigned i = 0; i < allocs.size(); ++i) {
+    Value original = allocs[i];
+    // Get the new block argument for this specific allocated value.
+    Value newArg =
+        newTargetBlock->getArgument(hostEvalVarsSize + originalMapVarsSize + i);
+    Value restored;
+    // If the original value is a pointer or reference, load and convert if
+    // necessary.
+    if (isPtr(original.getType())) {
+      restored = rewriter.create<LLVM::LoadOp>(loc, llvmPtrTy, newArg);
+      if (!isa<LLVM::LLVMPointerType>(original.getType()))
+        restored =
+            rewriter.create<fir::ConvertOp>(loc, original.getType(), restored);
+    } else {
+      restored = rewriter.create<fir::LoadOp>(loc, newArg);
+    }
+    irMapping.map(original, restored);
+  }
+  // Clone the operations if they are in the toRecompute set.
+  for (auto it = targetBlock->begin(); it != splitBefore->getIterator(); it++) {
+    if (toRecompute.contains(&*it))
+      rewriter.clone(*it, irMapping);
+  }
+}
+
+/// Given a teamsOp, navigate down the nested structure to find the
+/// innermost LoopNestOp. The expected nesting is:
+/// teams -> parallel -> distribute -> wsloop -> loop_nest
+static mlir::omp::LoopNestOp getLoopNestFromTeams(mlir::omp::TeamsOp teamsOp) {
+  if (teamsOp.getRegion().empty())
+    return nullptr;
+  // Ensure the teams region has a single block.
+  if (teamsOp.getRegion().getBlocks().size() != 1)
+    return nullptr;
+  // Find the parallel op inside teams.
+  mlir::omp::ParallelOp parallelOp = nullptr;
+  for (auto &op : teamsOp.getRegion().front()) {
+    if (auto parallel = dyn_cast<mlir::omp::ParallelOp>(op)) {
+      parallelOp = parallel;
+      break;
+    }
+  }
+  if (!parallelOp)
+    return nullptr;
+
+  // Find the distribute op inside parallel.
+  mlir::omp::DistributeOp distributeOp = nullptr;
+  for (auto &op : parallelOp.getRegion().front()) {
+    if (auto distribute = dyn_cast<mlir::omp::DistributeOp>(op)) {
+      distributeOp = distribute;
+      break;
+    }
+  }
+  if (!distributeOp)
+    return nullptr;
+
+  // Find the wsloop op inside distribute.
+  mlir::omp::WsloopOp wsloopOp = nullptr;
+  for (auto &op : distributeOp.getRegion().front()) {
+    if (auto wsloop = dyn_cast<mlir::omp::WsloopOp>(op)) {
+      wsloopOp = wsloop;
+      break;
+    }
+  }
+  if (!wsloopOp)
+    return nullptr;
+
+  // Find the loop_nest op inside wsloop.
+  for (auto &op : wsloopOp.getRegion().front()) {
+    if (auto loopNest = dyn_cast<mlir::omp::LoopNestOp>(op)) {
+      return loopNest;
+    }
+  }
+
+  return nullptr;
+}
+
+/// Generate an LLVM constant op holding a given i32 value.
+static mlir::LLVM::ConstantOp
+genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) {
+  mlir::Type i32Ty = rewriter.getI32Type();
+  mlir::IntegerAttr attr = rewriter.getI32IntegerAttr(value);
+  return rewriter.create<mlir::LLVM::ConstantOp>(loc, i32Ty, attr);
+}
+
+/// Given a box descriptor, extract the base address of the data it describes.
+/// If the box descriptor is a reference, load it first.
+/// The base address is returned as an i8* pointer.
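+/// Rough shape of the generated FIR (types illustrative):
+///   %box  = fir.load %boxDesc         // only when given a reference to a box
+///   %cast = fir.convert %box : (...) -> !fir.box<!fir.array<?xi8>>
+///   %addr = fir.box_addr %cast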
+static Value genDescriptorGetBaseAddress(fir::FirOpBuilder &builder,
+                                         Location loc, Value boxDesc) {
+  Value box = boxDesc;
+  if (auto refBox = dyn_cast<fir::ReferenceType>(boxDesc.getType())) {
+    box = fir::LoadOp::create(builder, loc, boxDesc);
+  }
+  assert(isa<fir::BaseBoxType>(box.getType()) &&
+         "Unknown type passed to genDescriptorGetBaseAddress");
+  auto i8Type = builder.getI8Type();
+  auto unknownArrayType =
+      fir::SequenceType::get({fir::SequenceType::getUnknownExtent()}, i8Type);
+  auto i8BoxType = fir::BoxType::get(unknownArrayType);
+  auto typedBox = fir::ConvertOp::create(builder, loc, i8BoxType, box);
+  auto rawAddr = fir::BoxAddrOp::create(builder, loc, typedBox);
+  return rawAddr;
+}
+
+/// Given a box descriptor, extract the total number of elements in the array
+/// it describes. If the box descriptor is a reference, load it first.
+/// The total number of elements is returned as an i64 value.
+static Value genDescriptorGetTotalElements(fir::FirOpBuilder &builder,
+                                           Location loc, Value boxDesc) {
+  Value box = boxDesc;
+  if (auto refBox = dyn_cast<fir::ReferenceType>(boxDesc.getType())) {
+    box = fir::LoadOp::create(builder, loc, boxDesc);
+  }
+  assert(isa<fir::BaseBoxType>(box.getType()) &&
+         "Unknown type passed to genDescriptorGetTotalElements");
+  auto i64Type = builder.getI64Type();
+  return fir::BoxTotalElementsOp::create(builder, loc, i64Type, box);
+}
+
+/// Given a box descriptor, extract the size of each element in the array it
+/// describes. If the box descriptor is a reference, load it first.
+/// The element size is returned as an i64 value.
+static Value genDescriptorGetEleSize(fir::FirOpBuilder &builder, Location loc,
+                                     Value boxDesc) {
+  Value box = boxDesc;
+  if (auto refBox = dyn_cast<fir::ReferenceType>(boxDesc.getType())) {
+    box = fir::LoadOp::create(builder, loc, boxDesc);
+  }
+  assert(isa<fir::BaseBoxType>(box.getType()) &&
+         "Unknown type passed to genDescriptorGetEleSize");
+  auto i64Type = builder.getI64Type();
+  return fir::BoxEleSizeOp::create(builder, loc, i64Type, box);
+}
+
+/// Given a box descriptor, compute the total size in bytes of the data it
+/// describes. This is done by multiplying the total number of elements by the
+/// size of each element. If the box descriptor is a reference, load it first.
+/// The total size in bytes is returned as an i64 value.
+static Value genDescriptorGetDataSizeInBytes(fir::FirOpBuilder &builder,
+                                             Location loc, Value boxDesc) {
+  Value box = boxDesc;
+  if (auto refBox = dyn_cast<fir::ReferenceType>(boxDesc.getType())) {
+    box = fir::LoadOp::create(builder, loc, boxDesc);
+  }
+  assert(isa<fir::BaseBoxType>(box.getType()) &&
+         "Unknown type passed to genDescriptorGetDataSizeInBytes");
+  Value eleSize = genDescriptorGetEleSize(builder, loc, box);
+  Value totalElements = genDescriptorGetTotalElements(builder, loc, box);
+  return mlir::arith::MulIOp::create(builder, loc, totalElements, eleSize);
+}
+
+/// Generate a call to the OpenMP runtime function `omp_get_mapped_ptr` to
+/// retrieve the device pointer corresponding to a given host pointer and
+/// device number. If no mapping exists, the original host pointer is returned.
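+/// This lets host-only (never-mapped) pointers flow through unchanged.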
+/// Signature:
+/// void *omp_get_mapped_ptr(void *host_ptr, int device_num);
+static mlir::Value genOmpGetMappedPtrIfPresent(fir::FirOpBuilder &builder,
+                                               mlir::Location loc,
+                                               mlir::Value hostPtr,
+                                               mlir::Value deviceNum,
+                                               mlir::ModuleOp module) {
+  auto *context = builder.getContext();
+  auto voidPtrType = fir::LLVMPointerType::get(context, builder.getI8Type());
+  auto i32Type = builder.getI32Type();
+  auto funcName = "omp_get_mapped_ptr";
+  auto funcOp = module.lookupSymbol<mlir::func::FuncOp>(funcName);
+
+  if (!funcOp) {
+    auto funcType =
+        mlir::FunctionType::get(context, {voidPtrType, i32Type}, {voidPtrType});
+
+    mlir::OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPointToStart(module.getBody());
+
+    funcOp = mlir::func::FuncOp::create(builder, loc, funcName, funcType);
+    funcOp.setPrivate();
+  }
+
+  llvm::SmallVector<mlir::Value> args;
+  args.push_back(fir::ConvertOp::create(builder, loc, voidPtrType, hostPtr));
+  args.push_back(fir::ConvertOp::create(builder, loc, i32Type, deviceNum));
+  auto callOp = fir::CallOp::create(builder, loc, funcOp, args);
+  auto mappedPtr = callOp.getResult(0);
+  auto isNull = builder.genIsNullAddr(loc, mappedPtr);
+  auto convertedHostPtr =
+      fir::ConvertOp::create(builder, loc, voidPtrType, hostPtr);
+  auto result = arith::SelectOp::create(builder, loc, isNull, convertedHostPtr,
+                                        mappedPtr);
+  return result;
+}
+
+/// Generate a call to the OpenMP runtime function `omp_target_memcpy` to
+/// perform memory copy between host and device or between devices.
+/// Signature:
+/// int omp_target_memcpy(void *dst, const void *src, size_t length,
+///                       size_t dst_offset, size_t src_offset,
+///                       int dst_device, int src_device);
+static void genOmpTargetMemcpyCall(fir::FirOpBuilder &builder,
+                                   mlir::Location loc, mlir::Value dst,
+                                   mlir::Value src, mlir::Value length,
+                                   mlir::Value dstOffset, mlir::Value srcOffset,
+                                   mlir::Value device, mlir::ModuleOp module) {
+  auto *context = builder.getContext();
+  auto funcName = "omp_target_memcpy";
+  auto voidPtrType = fir::LLVMPointerType::get(context, builder.getI8Type());
+  auto sizeTType = builder.getI64Type(); // assuming size_t is 64-bit
+  auto i32Type = builder.getI32Type();
+  auto funcOp = module.lookupSymbol<mlir::func::FuncOp>(funcName);
+
+  if (!funcOp) {
+    mlir::OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPointToStart(module.getBody());
+    llvm::SmallVector<mlir::Type> argTypes = {
+        voidPtrType, voidPtrType, sizeTType, sizeTType,
+        sizeTType,   i32Type,     i32Type};
+    auto funcType = mlir::FunctionType::get(context, argTypes, {i32Type});
+    funcOp = mlir::func::FuncOp::create(builder, loc, funcName, funcType);
+    funcOp.setPrivate();
+  }
+
+  llvm::SmallVector<mlir::Value> args{dst,       src,    length, dstOffset,
+                                      srcOffset, device, device};
+  fir::CallOp::create(builder, loc, funcOp, args);
+  return;
+}
+
+/// Generate code to replace a Fortran array assignment call with OpenMP
+/// runtime calls that perform the equivalent operation on the device.
+/// This involves extracting the source and destination pointers from the
+/// Fortran array descriptors, retrieving their mapped device pointers (if
+/// any), and invoking `omp_target_memcpy` to copy the data on the device.
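+/// Sketch of the emitted sequence (pseudo-code, names illustrative):
+///   dst = omp_get_mapped_ptr(base(dest), device)  // falls back to host ptr
+///   src = omp_get_mapped_ptr(base(src), device)
+///   omp_target_memcpy(dst, src, byte_size(src), /*offsets=*/0, 0,
+///                     device, device)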
+static void genFortranAssignOmpReplacement(fir::FirOpBuilder &builder,
+                                           mlir::Location loc,
+                                           fir::CallOp callOp,
+                                           mlir::Value device,
+                                           mlir::ModuleOp module) {
+  assert(callOp.getNumResults() == 0 &&
+         "Expected _FortranAAssign to have no results");
+  assert(callOp.getNumOperands() >= 2 &&
+         "Expected _FortranAAssign to have at least two operands");
+
+  // Extract the source and destination pointers from the call operands.
+  mlir::Value dest = callOp.getOperand(0);
+  mlir::Value src = callOp.getOperand(1);
+
+  // Get the base addresses of the source and destination arrays.
+  mlir::Value srcBase = genDescriptorGetBaseAddress(builder, loc, src);
+  mlir::Value destBase = genDescriptorGetBaseAddress(builder, loc, dest);
+
+  // Get the total size in bytes of the data to be copied.
+  mlir::Value srcDataSize = genDescriptorGetDataSizeInBytes(builder, loc, src);
+
+  // Retrieve the mapped device pointers for source and destination.
+  // If no mapping exists, the original host pointer is used.
+  Value destPtr =
+      genOmpGetMappedPtrIfPresent(builder, loc, destBase, device, module);
+  Value srcPtr =
+      genOmpGetMappedPtrIfPresent(builder, loc, srcBase, device, module);
+  Value zero = builder.create<mlir::arith::ConstantOp>(
+      loc, builder.getI64Type(), builder.getI64IntegerAttr(0));
+
+  // Generate the call to omp_target_memcpy to perform the data copy on the
+  // device.
+  genOmpTargetMemcpyCall(builder, loc, destPtr, srcPtr, srcDataSize, zero, zero,
+                         device, module);
+}
+
+/// Struct to hold the host_eval vars corresponding to loop bounds and steps.
+struct HostEvalVars {
+  SmallVector<Value> lbs;
+  SmallVector<Value> ubs;
+  SmallVector<Value> steps;
+};
+
+/// moveToHost clones all the ops from the target region out of it. It hoists
+/// the runtime call "_FortranAAssign" and replaces it with its OpenMP
+/// equivalent, and it hoists fir.allocmem/fir.freemem and replaces them with
+/// omp.target_allocmem/omp.target_freemem.
+static LogicalResult moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
+                                mlir::ModuleOp module,
+                                struct HostEvalVars &hostEvalVars) {
+  OpBuilder::InsertionGuard guard(rewriter);
+  Block *targetBlock = &targetOp.getRegion().front();
+  assert(targetBlock == &targetOp.getRegion().back());
+  IRMapping mapping;
+
+  // Get the parent target_data op.
+  auto targetDataOp = dyn_cast<omp::TargetDataOp>(targetOp->getParentOp());
+  if (!targetDataOp) {
+    emitError(targetOp->getLoc(),
+              "Expected target op to be inside target_data op");
+    return failure();
+  }
+  // Create mapping for host_eval vars.
+  unsigned hostEvalVarCount = targetOp.getHostEvalVars().size();
+  for (unsigned i = 0; i < targetOp.getHostEvalVars().size(); ++i) {
+    Value hostEvalVar = targetOp.getHostEvalVars()[i];
+    BlockArgument arg = targetBlock->getArguments()[i];
+    mapping.map(arg, hostEvalVar);
+  }
+  // Create mapping for map_vars.
+  for (unsigned i = 0; i < targetOp.getMapVars().size(); ++i) {
+    Value mapInfo = targetOp.getMapVars()[i];
+    BlockArgument arg = targetBlock->getArguments()[hostEvalVarCount + i];
+    Operation *op = mapInfo.getDefiningOp();
+    assert(op);
+    auto mapInfoOp = cast<omp::MapInfoOp>(op);
+    // Map the block argument to the host-side variable pointer.
+    mapping.map(arg, mapInfoOp.getVarPtr());
+  }
+  // Create mapping for private_vars.
+  unsigned mapSize = targetOp.getMapVars().size();
+  for (unsigned i = 0; i < targetOp.getPrivateVars().size(); ++i) {
+    Value privateVar = targetOp.getPrivateVars()[i];
+    // The mapping should link the device-side variable to the host-side one.
+    BlockArgument arg =
+        targetBlock->getArguments()[hostEvalVarCount + mapSize + i];
+    // Map the device-side copy (`arg`) to the host-side value (`privateVar`).
+    mapping.map(arg, privateVar);
+  }
+
+  rewriter.setInsertionPoint(targetOp);
+  SmallVector<Operation *> opsToReplace;
+  Value device = targetOp.getDevice();
+
+  // If device is not specified, default to device 0.
+  if (!device) {
+    device = genI32Constant(targetOp.getLoc(), rewriter, 0);
+  }
+  // Clone all operations.
+  for (auto it = targetBlock->begin(), end = std::prev(targetBlock->end());
+       it != end; ++it) {
+    auto *op = &*it;
+    Operation *clonedOp = rewriter.clone(*op, mapping);
+    // Map the results of the original op to the cloned op.
+    for (unsigned i = 0; i < op->getNumResults(); ++i) {
+      mapping.map(op->getResult(i), clonedOp->getResult(i));
+    }
+    // fir.declare changes its type when hoisted out of omp.target into
+    // omp.target_data. Introduce a load if the original declare op's input is
+    // not of reference type but the cloned declare op's input is.
+    if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(clonedOp)) {
+      auto originalDeclareOp = cast<fir::DeclareOp>(op);
+      Type originalInType = originalDeclareOp.getMemref().getType();
+      Type clonedInType = clonedDeclareOp.getMemref().getType();
+
+      fir::ReferenceType originalRefType =
+          dyn_cast<fir::ReferenceType>(originalInType);
+      fir::ReferenceType clonedRefType =
+          dyn_cast<fir::ReferenceType>(clonedInType);
+      if (!originalRefType && clonedRefType) {
+        Type clonedEleTy = clonedRefType.getElementType();
+        if (clonedEleTy == originalDeclareOp.getType()) {
+          opsToReplace.push_back(clonedOp);
+        }
+      }
+    }
+    // Collect the ops to be replaced.
+    if (isa<fir::AllocMemOp>(clonedOp) || isa<fir::FreeMemOp>(clonedOp))
+      opsToReplace.push_back(clonedOp);
+    // Check for runtime calls to be replaced.
+    if (isRuntimeCall(clonedOp)) {
+      fir::CallOp runtimeCall = cast<fir::CallOp>(op);
+      auto funcName = runtimeCall.getCallee()->getRootReference().getValue();
+      if (funcName == FortranAssignStr) {
+        opsToReplace.push_back(clonedOp);
+      } else {
+        emitError(runtimeCall->getLoc(), "Unhandled runtime call hoisting.");
+        return failure();
+      }
+    }
+  }
+  // Replace fir.allocmem with omp.target_allocmem.
+  for (Operation *op : opsToReplace) {
+    if (auto allocOp = dyn_cast<fir::AllocMemOp>(op)) {
+      rewriter.setInsertionPoint(allocOp);
+      auto ompAllocmemOp = rewriter.create<omp::TargetAllocMemOp>(
+          allocOp.getLoc(), rewriter.getI64Type(), device,
+          allocOp.getInTypeAttr(), allocOp.getUniqNameAttr(),
+          allocOp.getBindcNameAttr(), allocOp.getTypeparams(),
+          allocOp.getShape());
+      auto firConvertOp = rewriter.create<fir::ConvertOp>(
+          allocOp.getLoc(), allocOp.getResult().getType(),
+          ompAllocmemOp.getResult());
+      rewriter.replaceOp(allocOp, firConvertOp.getResult());
+    }
+    // Replace fir.freemem with omp.target_freemem.
+    else if (auto freeOp = dyn_cast<fir::FreeMemOp>(op)) {
+      rewriter.setInsertionPoint(freeOp);
+      auto firConvertOp = rewriter.create<fir::ConvertOp>(
+          freeOp.getLoc(), rewriter.getI64Type(), freeOp.getHeapref());
+      rewriter.create<omp::TargetFreeMemOp>(freeOp.getLoc(), device,
+                                            firConvertOp.getResult());
+      rewriter.eraseOp(freeOp);
+    }
+    // fir.declare changes its type when hoisted out of omp.target into
+    // omp.target_data. Introduce a load if the original declare op's input is
+    // not of reference type but the cloned declare op's input is.
+    else if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(op)) {
+      Type clonedInType = clonedDeclareOp.getMemref().getType();
+      fir::ReferenceType clonedRefType =
+          dyn_cast<fir::ReferenceType>(clonedInType);
+      Type clonedEleTy = clonedRefType.getElementType();
+      rewriter.setInsertionPoint(op);
+      Value loadedValue = rewriter.create<fir::LoadOp>(
+          clonedDeclareOp.getLoc(), clonedEleTy, clonedDeclareOp.getMemref());
+      clonedDeclareOp.getResult().replaceAllUsesWith(loadedValue);
+    }
+    // Replace runtime calls with their OpenMP versions.
+    else if (isRuntimeCall(op)) {
+      fir::CallOp runtimeCall = cast<fir::CallOp>(op);
+      auto funcName = runtimeCall.getCallee()->getRootReference().getValue();
+      if (funcName == FortranAssignStr) {
+        rewriter.setInsertionPoint(op);
+        fir::FirOpBuilder builder{rewriter, op};
+
+        mlir::Location loc = runtimeCall.getLoc();
+        genFortranAssignOmpReplacement(builder, loc, runtimeCall, device,
+                                       module);
+        rewriter.eraseOp(op);
+      } else {
+        emitError(runtimeCall->getLoc(), "Unhandled runtime call hoisting.");
+        return failure();
+      }
+    } else {
+      emitError(op->getLoc(), "Unhandled op hoisting.");
+      return failure();
+    }
+  }
+
+  // Update the host_eval vars to use the mapped values.
+  for (size_t i = 0; i < hostEvalVars.lbs.size(); ++i) {
+    hostEvalVars.lbs[i] = mapping.lookup(hostEvalVars.lbs[i]);
+    hostEvalVars.ubs[i] = mapping.lookup(hostEvalVars.ubs[i]);
+    hostEvalVars.steps[i] = mapping.lookup(hostEvalVars.steps[i]);
+  }
+  // Finally erase the original targetOp.
+  rewriter.eraseOp(targetOp);
+  return success();
+}
+
+/// Result of the isolateOp method.
+struct SplitResult {
+  omp::TargetOp preTargetOp;
+  omp::TargetOp isolatedTargetOp;
+  omp::TargetOp postTargetOp;
+};
+
+/// computeAllocsCacheRecomputable computes the allocs needed to cache the
+/// values that are used outside the split point. It also computes the ops
+/// that need to be cached and the ops that can be recomputed after the split.
+static void computeAllocsCacheRecomputable(
+    omp::TargetOp targetOp, Operation *splitBeforeOp, RewriterBase &rewriter,
+    SmallVector<Value> &preMapOperands, SmallVector<Value> &postMapOperands,
+    SmallVector<Value> &allocs, SmallVector<Value> &requiredVals,
+    SetVector<Operation *> &nonRecomputable, SetVector<Operation *> &toCache,
+    SetVector<Operation *> &toRecompute) {
+  auto *targetBlock = &targetOp.getRegion().front();
+  // Find all values that are used outside the split point.
+  for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator();
+       it++) {
+    // Check if any of the results are used outside the split point.
+    for (auto res : it->getResults()) {
+      if (usedOutsideSplit(res, splitBeforeOp)) {
+        requiredVals.push_back(res);
+      }
+    }
+    // If the op is not recomputable, add it to the nonRecomputable set.
+    if (!isRecomputableAfterFission(&*it, splitBeforeOp)) {
+      nonRecomputable.insert(&*it);
+    }
+  }
+  // For each required value, collect its dependencies.
+  for (auto requiredVal : requiredVals)
+    collectNonRecomputableDeps(requiredVal, targetOp, nonRecomputable, toCache,
+                               toRecompute);
+  // For each op in toCache, create an alloc and update the pre and post map
+  // operands.
+  for (Operation *op : toCache) {
+    for (auto res : op->getResults()) {
+      auto alloc =
+          allocateTempOmpVar(targetOp.getLoc(), res.getType(), rewriter);
+      allocs.push_back(res);
+      preMapOperands.push_back(alloc.from);
+      postMapOperands.push_back(alloc.to);
+    }
+  }
+}
+
+/// genPreTargetOp generates the preTargetOp that contains all the ops before
+/// the split point. It creates the block arguments, maps the values
+/// accordingly, and creates the store operations for the allocs.
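+/// Sketch (illustrative): for a cached value %v, the pre kernel receives the
+/// "from"-mapped temporary as a trailing block argument and ends by storing
+/// %v into it; the later kernels receive the "to"-mapped half and reload it.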
+static omp::TargetOp
+genPreTargetOp(omp::TargetOp targetOp, SmallVector<Value> &preMapOperands,
+               SmallVector<Value> &allocs, Operation *splitBeforeOp,
+               RewriterBase &rewriter, struct HostEvalVars &hostEvalVars,
+               bool isTargetDevice) {
+  auto loc = targetOp.getLoc();
+  auto *targetBlock = &targetOp.getRegion().front();
+  SmallVector<Value> preHostEvalVars{targetOp.getHostEvalVars()};
+  // Update the hostEvalVars of preTargetOp.
+  omp::TargetOp preTargetOp = rewriter.create<omp::TargetOp>(
+      targetOp.getLoc(), targetOp.getAllocateVars(),
+      targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+      targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), preHostEvalVars,
+      targetOp.getIfExpr(), targetOp.getInReductionVars(),
+      targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
+      targetOp.getIsDevicePtrVars(), preMapOperands, targetOp.getNowaitAttr(),
+      targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
+      targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(),
+      targetOp.getPrivateMapsAttr());
+  auto *preTargetBlock = rewriter.createBlock(
+      &preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {});
+  IRMapping preMapping;
+  // Create block arguments and map the values.
+  createBlockArgsAndMap(loc, rewriter, targetOp, targetBlock, preTargetBlock,
+                        preHostEvalVars, preMapOperands, allocs, preMapping);
+
+  // Handle the store operations for the allocs.
+  rewriter.setInsertionPointToStart(preTargetBlock);
+  auto llvmPtrTy = LLVM::LLVMPointerType::get(targetOp.getContext());
+
+  // Clone the original operations.
+  for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator();
+       it++) {
+    rewriter.clone(*it, preMapping);
+  }
+
+  unsigned originalHostEvalVarsSize = preHostEvalVars.size();
+  unsigned originalMapVarsSize = targetOp.getMapVars().size();
+  // Create stores for the allocs.
+  for (unsigned i = 0; i < allocs.size(); ++i) {
+    Value originalResult = allocs[i];
+    Value toStore = preMapping.lookup(originalResult);
+    // Get the new block argument for this specific allocated value.
+    Value newArg = preTargetBlock->getArgument(originalHostEvalVarsSize +
+                                               originalMapVarsSize + i);
+    // Create the store operation.
+    if (isPtr(originalResult.getType())) {
+      if (!isa<LLVM::LLVMPointerType>(toStore.getType()))
+        toStore = rewriter.create<fir::ConvertOp>(loc, llvmPtrTy, toStore);
+      rewriter.create<LLVM::StoreOp>(loc, toStore, newArg);
+    } else {
+      rewriter.create<fir::StoreOp>(loc, toStore, newArg);
+    }
+  }
+  rewriter.create<omp::TerminatorOp>(loc);
+
+  // Update hostEvalVars with the mapped values for the loop bounds if we have
+  // a loopNestOp and we are not generating code for the target device.
+  omp::LoopNestOp loopNestOp =
+      getLoopNestFromTeams(cast<omp::TeamsOp>(splitBeforeOp));
+  if (loopNestOp && !isTargetDevice) {
+    for (size_t i = 0; i < loopNestOp.getLoopLowerBounds().size(); ++i) {
+      Value lb = loopNestOp.getLoopLowerBounds()[i];
+      Value ub = loopNestOp.getLoopUpperBounds()[i];
+      Value step = loopNestOp.getLoopSteps()[i];
+
+      hostEvalVars.lbs.push_back(preMapping.lookup(lb));
+      hostEvalVars.ubs.push_back(preMapping.lookup(ub));
+      hostEvalVars.steps.push_back(preMapping.lookup(step));
+    }
+  }
+
+  return preTargetOp;
+}
+
+/// genIsolatedTargetOp generates the isolatedTargetOp that contains the op at
+/// the split point. It creates the block arguments, maps the values
+/// accordingly, creates the load operations for the allocs, and recomputes
+/// the necessary ops.
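+/// Cached values arrive through the "to" halves of the temporaries created by
+/// allocateTempOmpVar and are reloaded at the top of the isolated region.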
+static omp::TargetOp
+genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector<Value> &postMapOperands,
+                    Operation *splitBeforeOp, RewriterBase &rewriter,
+                    SmallVector<Value> &allocs,
+                    SetVector<Operation *> &toRecompute,
+                    struct HostEvalVars &hostEvalVars, bool isTargetDevice) {
+  auto loc = targetOp.getLoc();
+  auto *targetBlock = &targetOp.getRegion().front();
+  SmallVector<Value> isolatedHostEvalVars{targetOp.getHostEvalVars()};
+  // Update the hostEvalVars of isolatedTargetOp.
+  if (!hostEvalVars.lbs.empty() && !isTargetDevice) {
+    isolatedHostEvalVars.append(hostEvalVars.lbs.begin(),
+                                hostEvalVars.lbs.end());
+    isolatedHostEvalVars.append(hostEvalVars.ubs.begin(),
+                                hostEvalVars.ubs.end());
+    isolatedHostEvalVars.append(hostEvalVars.steps.begin(),
+                                hostEvalVars.steps.end());
+  }
+  // Create the isolated target op.
+  omp::TargetOp isolatedTargetOp = rewriter.create<omp::TargetOp>(
+      targetOp.getLoc(), targetOp.getAllocateVars(),
+      targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+      targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+      isolatedHostEvalVars, targetOp.getIfExpr(), targetOp.getInReductionVars(),
+      targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
+      targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(),
+      targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
+      targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(),
+      targetOp.getPrivateMapsAttr());
+  auto *isolatedTargetBlock =
+      rewriter.createBlock(&isolatedTargetOp.getRegion(),
+                           isolatedTargetOp.getRegion().begin(), {}, {});
+  IRMapping isolatedMapping;
+  // Create block arguments and map the values.
+  createBlockArgsAndMap(loc, rewriter, targetOp, targetBlock,
+                        isolatedTargetBlock, isolatedHostEvalVars,
+                        postMapOperands, allocs, isolatedMapping);
+  // Handle the load operations for the allocs and recompute ops.
+  reloadCacheAndRecompute(loc, rewriter, splitBeforeOp, targetOp, targetBlock,
+                          isolatedTargetBlock, isolatedHostEvalVars,
+                          postMapOperands, allocs, toRecompute,
+                          isolatedMapping);
+
+  // Clone the isolated operation.
+  rewriter.clone(*splitBeforeOp, isolatedMapping);
+  rewriter.create<omp::TerminatorOp>(loc);
+
+  // Update the loop bounds in the isolatedTargetOp if we have host_eval vars
+  // and we are not generating code for the target device.
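+  // The appended host_eval block arguments start immediately after the
+  // original ones, i.e. at index targetOp.getHostEvalVars().size().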
+  if (!hostEvalVars.lbs.empty() && !isTargetDevice) {
+    omp::TeamsOp teamsOp;
+    for (auto &op : *isolatedTargetBlock) {
+      if (isa<omp::TeamsOp>(&op))
+        teamsOp = cast<omp::TeamsOp>(&op);
+    }
+    assert(teamsOp && "No teamsOp found in isolated target region");
+    // Get the loopNestOp inside the teamsOp.
+    auto loopNestOp = getLoopNestFromTeams(teamsOp);
+    // Get the block args related to host_eval vars and update the loop_nest
+    // bounds to use them.
+    unsigned originalHostEvalVarsSize = targetOp.getHostEvalVars().size();
+    unsigned index = originalHostEvalVarsSize;
+    // Replace loop bounds with the block arguments passed down via host_eval.
+    SmallVector<Value> lbs, ubs, steps;
+
+    // Collect new lb/ub/step values from target block args.
+    for (size_t i = 0; i < hostEvalVars.lbs.size(); ++i)
+      lbs.push_back(isolatedTargetBlock->getArgument(index++));
+
+    for (size_t i = 0; i < hostEvalVars.ubs.size(); ++i)
+      ubs.push_back(isolatedTargetBlock->getArgument(index++));
+
+    for (size_t i = 0; i < hostEvalVars.steps.size(); ++i)
+      steps.push_back(isolatedTargetBlock->getArgument(index++));
+
+    // Reset the loop bounds.
+    loopNestOp.getLoopLowerBoundsMutable().assign(lbs);
+    loopNestOp.getLoopUpperBoundsMutable().assign(ubs);
+    loopNestOp.getLoopStepsMutable().assign(steps);
+  }
+
+  return isolatedTargetOp;
+}
+
+/// genPostTargetOp generates the postTargetOp that contains all the ops after
+/// the split point. It creates the block arguments, maps the values
+/// accordingly, creates the load operations for the allocs, and recomputes
+/// the necessary ops.
+static omp::TargetOp genPostTargetOp(omp::TargetOp targetOp,
+                                     Operation *splitBeforeOp,
+                                     SmallVector<Value> &postMapOperands,
+                                     RewriterBase &rewriter,
+                                     SmallVector<Value> &allocs,
+                                     SetVector<Operation *> &toRecompute) {
+  auto loc = targetOp.getLoc();
+  auto *targetBlock = &targetOp.getRegion().front();
+  SmallVector<Value> postHostEvalVars{targetOp.getHostEvalVars()};
+  // Create the post target op.
+  omp::TargetOp postTargetOp = rewriter.create<omp::TargetOp>(
+      targetOp.getLoc(), targetOp.getAllocateVars(),
+      targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+      targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), postHostEvalVars,
+      targetOp.getIfExpr(), targetOp.getInReductionVars(),
+      targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
+      targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(),
+      targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
+      targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(),
+      targetOp.getPrivateMapsAttr());
+  // Create the block for postTargetOp.
+  auto *postTargetBlock = rewriter.createBlock(
+      &postTargetOp.getRegion(), postTargetOp.getRegion().begin(), {}, {});
+  IRMapping postMapping;
+  // Create block arguments and map the values.
+  createBlockArgsAndMap(loc, rewriter, targetOp, targetBlock, postTargetBlock,
+                        postHostEvalVars, postMapOperands, allocs, postMapping);
+  // Handle the load operations for the allocs and recompute ops.
+  reloadCacheAndRecompute(loc, rewriter, splitBeforeOp, targetOp, targetBlock,
+                          postTargetBlock, postHostEvalVars, postMapOperands,
+                          allocs, toRecompute, postMapping);
+  assert(splitBeforeOp->getNumResults() == 0 ||
+         llvm::all_of(splitBeforeOp->getResults(),
+                      [](Value result) { return result.use_empty(); }));
+  // Clone the original operations after the split point.
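+  // These clones run in the post kernel against the reloaded or recomputed
+  // values recorded in postMapping.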
+  for (auto it = std::next(splitBeforeOp->getIterator());
+       it != targetBlock->end(); it++)
+    rewriter.clone(*it, postMapping);
+  return postTargetOp;
+}
+
+/// isolateOp rewrites an omp.target_data { omp.target } nest into
+/// omp.target_data {
+///   // preTargetOp region contains the ops before splitBeforeOp.
+///   omp.target {}
+///   // isolatedTargetOp region contains splitBeforeOp.
+///   omp.target {}
+///   // postTargetOp region contains the ops after splitBeforeOp.
+///   omp.target {}
+/// }
+/// It also handles the mapping of variables and the caching/recomputing
+/// of values as needed.
+static FailureOr<SplitResult> isolateOp(Operation *splitBeforeOp,
+                                        bool splitAfter, RewriterBase &rewriter,
+                                        mlir::ModuleOp module,
+                                        bool isTargetDevice) {
+  auto targetOp = cast<omp::TargetOp>(splitBeforeOp->getParentOp());
+  assert(targetOp);
+  rewriter.setInsertionPoint(targetOp);
+
+  // Prepare the map operands for preTargetOp and postTargetOp.
+  auto preMapOperands = SmallVector<Value>(targetOp.getMapVars());
+  auto postMapOperands = SmallVector<Value>(targetOp.getMapVars());
+
+  // Vectors to hold analysis results.
+  SmallVector<Value> requiredVals;
+  SetVector<Operation *> toCache;
+  SetVector<Operation *> toRecompute;
+  SetVector<Operation *> nonRecomputable;
+  SmallVector<Value> allocs;
+  struct HostEvalVars hostEvalVars;
+
+  // Analyze the ops in the target region to determine which ops need to be
+  // cached and which ops need to be recomputed.
+  computeAllocsCacheRecomputable(
+      targetOp, splitBeforeOp, rewriter, preMapOperands, postMapOperands,
+      allocs, requiredVals, nonRecomputable, toCache, toRecompute);
+
+  rewriter.setInsertionPoint(targetOp);
+
+  // Generate the preTargetOp that contains all the ops before splitBeforeOp.
+  auto preTargetOp =
+      genPreTargetOp(targetOp, preMapOperands, allocs, splitBeforeOp, rewriter,
+                     hostEvalVars, isTargetDevice);
+
+  // Move the ops of preTargetOp to the host.
+  auto res = moveToHost(preTargetOp, rewriter, module, hostEvalVars);
+  if (failed(res))
+    return failure();
+  rewriter.setInsertionPoint(targetOp);
+
+  // Generate the isolatedTargetOp.
+  omp::TargetOp isolatedTargetOp =
+      genIsolatedTargetOp(targetOp, postMapOperands, splitBeforeOp, rewriter,
+                          allocs, toRecompute, hostEvalVars, isTargetDevice);
+
+  omp::TargetOp postTargetOp = nullptr;
+  // Generate the postTargetOp that contains all the ops after splitBeforeOp.
+  if (splitAfter) {
+    rewriter.setInsertionPoint(targetOp);
+    postTargetOp = genPostTargetOp(targetOp, splitBeforeOp, postMapOperands,
+                                   rewriter, allocs, toRecompute);
+  }
+  // Finally erase the original targetOp.
+  rewriter.eraseOp(targetOp);
+  return SplitResult{preTargetOp, isolatedTargetOp, postTargetOp};
+}
+
+/// Recursively fission target ops until no more nested ops can be isolated.
+static LogicalResult fissionTarget(omp::TargetOp targetOp,
+                                   RewriterBase &rewriter,
+                                   mlir::ModuleOp module, bool isTargetDevice) {
+  auto tuple = getNestedOpToIsolate(targetOp);
+  if (!tuple) {
+    LLVM_DEBUG(llvm::dbgs() << " No op to isolate\n");
+    struct HostEvalVars hostEvalVars;
+    return moveToHost(targetOp, rewriter, module, hostEvalVars);
+  }
+  Operation *toIsolate = std::get<0>(*tuple);
+  bool splitBefore = !std::get<1>(*tuple);
+  bool splitAfter = !std::get<2>(*tuple);
+  // Recursively isolate the target op.
+  if (splitBefore && splitAfter) {
+    auto res =
+        isolateOp(toIsolate, splitAfter, rewriter, module, isTargetDevice);
+    if (failed(res))
+      return failure();
+    return fissionTarget((*res).postTargetOp, rewriter, module, isTargetDevice);
+  }
+  // Isolate only before the op.
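+  // (no post kernel is needed here: the isolated op is already the last op in
+  // the region)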
+ if (splitBefore) { + auto res = + isolateOp(toIsolate, splitAfter, rewriter, module, isTargetDevice); + if (failed(res)) + return failure(); + } else { + emitError(toIsolate->getLoc(), "Unhandled case in fissionTarget"); + return failure(); + } + return success(); +} + +/// Pass to lower omp.workdistribute ops. +class LowerWorkdistributePass + : public flangomp::impl::LowerWorkdistributeBase { +public: + void runOnOperation() override { + MLIRContext &context = getContext(); + auto moduleOp = getOperation(); + bool changed = false; + SetVector targetOpsToProcess; + auto verify = + moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) { + if (failed(verifyTargetTeamsWorkdistribute(workdistribute))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + if (verify.wasInterrupted()) + return signalPassFailure(); + + auto fission = + moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) { + auto res = fissionWorkdistribute(workdistribute); + if (failed(res)) + return WalkResult::interrupt(); + changed |= *res; + return WalkResult::advance(); + }); + if (fission.wasInterrupted()) + return signalPassFailure(); + + auto rtCallLower = + moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) { + auto res = workdistributeRuntimeCallLower(workdistribute, + targetOpsToProcess); + if (failed(res)) + return WalkResult::interrupt(); + changed |= *res; + return WalkResult::advance(); + }); + if (rtCallLower.wasInterrupted()) + return signalPassFailure(); + + moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) { + changed |= workdistributeDoLower(workdistribute, targetOpsToProcess); + }); + + moduleOp->walk([&](mlir::omp::TeamsOp teams) { + changed |= teamsWorkdistributeToSingleOp(teams, targetOpsToProcess); + }); + if (changed) { + bool isTargetDevice = + llvm::cast(*moduleOp) + .getIsTargetDevice(); + IRRewriter rewriter(&context); + for (auto targetOp : targetOpsToProcess) { + auto res = splitTargetData(targetOp, rewriter); + if (failed(res)) + return signalPassFailure(); + if (*res) { + if (failed(fissionTarget(*res, rewriter, moduleOp, isTargetDevice))) + return signalPassFailure(); + } + } + } + } +}; +} // namespace diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp index 05d5ed141e7e4..34e5fa73111ab 100644 --- a/flang/lib/Optimizer/Passes/Pipelines.cpp +++ b/flang/lib/Optimizer/Passes/Pipelines.cpp @@ -294,8 +294,10 @@ void createHLFIRToFIRPassPipeline(mlir::PassManager &pm, bool enableOpenMP, addNestedPassToAllTopLevelOperations( pm, hlfir::createInlineHLFIRAssign); pm.addPass(hlfir::createConvertHLFIRtoFIR()); - if (enableOpenMP) + if (enableOpenMP) { pm.addPass(flangomp::createLowerWorkshare()); + pm.addPass(flangomp::createLowerWorkdistribute()); + } } /// Create a pass pipeline for handling certain OpenMP transformations needed diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir index c9fe53bf093a1..3a94ea8a476a1 100644 --- a/flang/test/Fir/basic-program.fir +++ b/flang/test/Fir/basic-program.fir @@ -69,6 +69,7 @@ func.func @_QQmain() { // PASSES-NEXT: InlineHLFIRAssign // PASSES-NEXT: ConvertHLFIRtoFIR // PASSES-NEXT: LowerWorkshare +// PASSES-NEXT: LowerWorkdistribute // PASSES-NEXT: CSE // PASSES-NEXT: (S) 0 num-cse'd - Number of operations CSE'd // PASSES-NEXT: (S) 0 num-dce'd - Number of operations DCE'd diff --git a/flang/test/Lower/OpenMP/workdistribute-multiple.f90 b/flang/test/Lower/OpenMP/workdistribute-multiple.f90 new file mode 100644 index 
0000000000000..97f24e13716d2 --- /dev/null +++ b/flang/test/Lower/OpenMP/workdistribute-multiple.f90 @@ -0,0 +1,20 @@ +! RUN: not %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - 2>&1 | FileCheck %s + +! CHECK: error: {{.*}} teams has multiple workdistribute ops. +! CHECK-LABEL: func @_QPteams_workdistribute_1 +subroutine teams_workdistribute_1() + use iso_fortran_env + real(kind=real32) :: a + real(kind=real32), dimension(10) :: x + real(kind=real32), dimension(10) :: y + !$omp teams + + !$omp workdistribute + y = a * x + y + !$omp end workdistribute + + !$omp workdistribute + y = a * y + x + !$omp end workdistribute + !$omp end teams +end subroutine teams_workdistribute_1 diff --git a/flang/test/Lower/OpenMP/workdistribute-saxpy-1d.f90 b/flang/test/Lower/OpenMP/workdistribute-saxpy-1d.f90 new file mode 100644 index 0000000000000..b2dbc0f15121e --- /dev/null +++ b/flang/test/Lower/OpenMP/workdistribute-saxpy-1d.f90 @@ -0,0 +1,39 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s + +! CHECK-LABEL: func @_QPtarget_teams_workdistribute +subroutine target_teams_workdistribute() + use iso_fortran_env + real(kind=real32) :: a + real(kind=real32), dimension(10) :: x + real(kind=real32), dimension(10) :: y + + ! CHECK: omp.target_data + ! CHECK: omp.target + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + + !$omp target teams workdistribute + y = a * x + y + !$omp end target teams workdistribute +end subroutine target_teams_workdistribute + +! CHECK-LABEL: func @_QPteams_workdistribute +subroutine teams_workdistribute() + use iso_fortran_env + real(kind=real32) :: a + real(kind=real32), dimension(10) :: x + real(kind=real32), dimension(10) :: y + + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + + !$omp teams workdistribute + y = a * x + y + !$omp end teams workdistribute +end subroutine teams_workdistribute diff --git a/flang/test/Lower/OpenMP/workdistribute-saxpy-2d.f90 b/flang/test/Lower/OpenMP/workdistribute-saxpy-2d.f90 new file mode 100644 index 0000000000000..09e1211541edb --- /dev/null +++ b/flang/test/Lower/OpenMP/workdistribute-saxpy-2d.f90 @@ -0,0 +1,45 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s + +! CHECK-LABEL: func @_QPtarget_teams_workdistribute +subroutine target_teams_workdistribute(a, x, y, rows, cols) + use iso_fortran_env + implicit none + + integer, intent(in) :: rows, cols + real(kind=real32) :: a + real(kind=real32), dimension(rows, cols) :: x, y + + ! CHECK: omp.target_data + ! CHECK: omp.target + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + ! CHECK: fir.do_loop + + !$omp target teams workdistribute + y = a * x + y + !$omp end target teams workdistribute +end subroutine target_teams_workdistribute + +! CHECK-LABEL: func @_QPteams_workdistribute +subroutine teams_workdistribute(a, x, y, rows, cols) + use iso_fortran_env + implicit none + + integer, intent(in) :: rows, cols + real(kind=real32) :: a + real(kind=real32), dimension(rows, cols) :: x, y + + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + ! 
CHECK: fir.do_loop + + !$omp teams workdistribute + y = a * x + y + !$omp end teams workdistribute +end subroutine teams_workdistribute diff --git a/flang/test/Lower/OpenMP/workdistribute-saxpy-3d.f90 b/flang/test/Lower/OpenMP/workdistribute-saxpy-3d.f90 new file mode 100644 index 0000000000000..cf5d0234edb39 --- /dev/null +++ b/flang/test/Lower/OpenMP/workdistribute-saxpy-3d.f90 @@ -0,0 +1,47 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s + +! CHECK-LABEL: func @_QPtarget_teams_workdistribute +subroutine target_teams_workdistribute(a, x, y, rows, cols, depth) + use iso_fortran_env + implicit none + + integer, intent(in) :: rows, cols, depth + real(kind=real32) :: a + real(kind=real32), dimension(rows, cols, depth) :: x, y + + ! CHECK: omp.target_data + ! CHECK: omp.target + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + ! CHECK: fir.do_loop + ! CHECK: fir.do_loop + + !$omp target teams workdistribute + y = a * x + y + !$omp end target teams workdistribute +end subroutine target_teams_workdistribute + +! CHECK-LABEL: func @_QPteams_workdistribute +subroutine teams_workdistribute(a, x, y, rows, cols, depth) + use iso_fortran_env + implicit none + + integer, intent(in) :: rows, cols, depth + real(kind=real32) :: a + real(kind=real32), dimension(rows, cols, depth) :: x, y + + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + ! CHECK: fir.do_loop + ! CHECK: fir.do_loop + + !$omp teams workdistribute + y = a * x + y + !$omp end teams workdistribute +end subroutine teams_workdistribute diff --git a/flang/test/Lower/OpenMP/workdistribute-saxpy-and-scalar-assign.f90 b/flang/test/Lower/OpenMP/workdistribute-saxpy-and-scalar-assign.f90 new file mode 100644 index 0000000000000..516c4603bd5da --- /dev/null +++ b/flang/test/Lower/OpenMP/workdistribute-saxpy-and-scalar-assign.f90 @@ -0,0 +1,53 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s + +! CHECK-LABEL: func @_QPtarget_teams_workdistribute +subroutine target_teams_workdistribute() + use iso_fortran_env + real(kind=real32) :: a + real(kind=real32), dimension(10) :: x + real(kind=real32), dimension(10) :: y + !$omp target teams workdistribute + + ! CHECK: omp.target_data + ! CHECK: omp.target + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + + y = a * x + y + + ! CHECK: omp.target + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + + y = 2.0_real32 + + !$omp end target teams workdistribute +end subroutine target_teams_workdistribute + +! CHECK-LABEL: func @_QPteams_workdistribute +subroutine teams_workdistribute() + use iso_fortran_env + real(kind=real32) :: a + real(kind=real32), dimension(10) :: x + real(kind=real32), dimension(10) :: y + !$omp teams workdistribute + + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + + y = a * x + y + + ! CHECK: fir.call @_FortranAAssign + y = 2.0_real32 + + !$omp end teams workdistribute +end subroutine teams_workdistribute diff --git a/flang/test/Lower/OpenMP/workdistribute-saxpy-two-2d.f90 b/flang/test/Lower/OpenMP/workdistribute-saxpy-two-2d.f90 new file mode 100644 index 0000000000000..4aeb2e89140cc --- /dev/null +++ b/flang/test/Lower/OpenMP/workdistribute-saxpy-two-2d.f90 @@ -0,0 +1,68 @@ +! 
RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s + +! CHECK-LABEL: func @_QPtarget_teams_workdistribute +subroutine target_teams_workdistribute(a, x, y, rows, cols) + use iso_fortran_env + implicit none + + integer, intent(in) :: rows, cols + real(kind=real32) :: a + real(kind=real32), dimension(rows, cols) :: x, y + + !$omp target teams workdistribute + + ! CHECK: omp.target_data + ! CHECK: omp.target + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + ! CHECK: fir.do_loop + + y = a * x + y + + ! CHECK: omp.target + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + ! CHECK: fir.do_loop + + y = a * y + x + + !$omp end target teams workdistribute +end subroutine target_teams_workdistribute + +! CHECK-LABEL: func @_QPteams_workdistribute +subroutine teams_workdistribute(a, x, y, rows, cols) + use iso_fortran_env + implicit none + + integer, intent(in) :: rows, cols + real(kind=real32) :: a + real(kind=real32), dimension(rows, cols) :: x, y + + !$omp teams workdistribute + + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + ! CHECK: fir.do_loop + + y = a * x + y + + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + ! CHECK: fir.do_loop + + y = a * y + x + + !$omp end teams workdistribute +end subroutine teams_workdistribute diff --git a/flang/test/Lower/OpenMP/workdistribute-scalar-assign.f90 b/flang/test/Lower/OpenMP/workdistribute-scalar-assign.f90 new file mode 100644 index 0000000000000..3062b3598b8ae --- /dev/null +++ b/flang/test/Lower/OpenMP/workdistribute-scalar-assign.f90 @@ -0,0 +1,29 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s + +! CHECK-LABEL: func @_QPtarget_teams_workdistribute_scalar_assign +subroutine target_teams_workdistribute_scalar_assign() + integer :: aa(10) + + ! CHECK: omp.target_data + ! CHECK: omp.target + ! CHECK: omp.teams + ! CHECK: omp.parallel + ! CHECK: omp.distribute + ! CHECK: omp.wsloop + ! CHECK: omp.loop_nest + + !$omp target teams workdistribute + aa = 20 + !$omp end target teams workdistribute + +end subroutine target_teams_workdistribute_scalar_assign + +! CHECK-LABEL: func @_QPteams_workdistribute_scalar_assign +subroutine teams_workdistribute_scalar_assign() + integer :: aa(10) + ! CHECK: fir.call @_FortranAAssign + !$omp teams workdistribute + aa = 20 + !$omp end teams workdistribute + +end subroutine teams_workdistribute_scalar_assign diff --git a/flang/test/Lower/OpenMP/workdistribute-target-teams-clauses.f90 b/flang/test/Lower/OpenMP/workdistribute-target-teams-clauses.f90 new file mode 100644 index 0000000000000..4a08e53bc316a --- /dev/null +++ b/flang/test/Lower/OpenMP/workdistribute-target-teams-clauses.f90 @@ -0,0 +1,32 @@ +! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s + +! CHECK-LABEL: func @_QPtarget_teams_workdistribute +! CHECK: omp.target_data map_entries({{.*}}) +! CHECK: omp.target thread_limit({{.*}}) host_eval({{.*}}) map_entries({{.*}}) +! CHECK: omp.teams num_teams({{.*}}) +! CHECK: omp.parallel +! CHECK: omp.distribute +! CHECK: omp.wsloop +! 
CHECK: omp.loop_nest + +subroutine target_teams_workdistribute() + use iso_fortran_env + real(kind=real32) :: a + real(kind=real32), dimension(10) :: x + real(kind=real32), dimension(10) :: y + integer :: i + + a = 2.0_real32 + x = [(real(i, real32), i = 1, 10)] + y = [(real(i * 0.5, real32), i = 1, 10)] + + !$omp target teams workdistribute & + !$omp& num_teams(4) & + !$omp& thread_limit(8) & + !$omp& default(shared) & + !$omp& private(i) & + !$omp& map(to: x) & + !$omp& map(tofrom: y) + y = a * x + y + !$omp end target teams workdistribute +end subroutine target_teams_workdistribute diff --git a/flang/test/Lower/OpenMP/workdistribute-teams-unsupported-after.f90 b/flang/test/Lower/OpenMP/workdistribute-teams-unsupported-after.f90 new file mode 100644 index 0000000000000..cf8902718f2ee --- /dev/null +++ b/flang/test/Lower/OpenMP/workdistribute-teams-unsupported-after.f90 @@ -0,0 +1,22 @@ +! RUN: not %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - 2>&1 | FileCheck %s + +! CHECK: error: {{.*}} teams has omp ops other than workdistribute. Lowering not implemented yet. +! CHECK-LABEL: func @_QPteams_workdistribute_1 +subroutine teams_workdistribute_1() + use iso_fortran_env + real(kind=real32) :: a + real(kind=real32), dimension(10) :: x + real(kind=real32), dimension(10) :: y + !$omp teams + + !$omp workdistribute + y = a * x + y + !$omp end workdistribute + + !$omp distribute + do i = 1, 10 + x(i) = real(i, kind=real32) + end do + !$omp end distribute + !$omp end teams +end subroutine teams_workdistribute_1 diff --git a/flang/test/Lower/OpenMP/workdistribute-teams-unsupported-before.f90 b/flang/test/Lower/OpenMP/workdistribute-teams-unsupported-before.f90 new file mode 100644 index 0000000000000..d957e147f9e04 --- /dev/null +++ b/flang/test/Lower/OpenMP/workdistribute-teams-unsupported-before.f90 @@ -0,0 +1,22 @@ +! RUN: not %flang_fc1 -emit-fir -fopenmp -fopenmp-version=60 %s -o - 2>&1 | FileCheck %s + +! CHECK: error: {{.*}} teams has omp ops other than workdistribute. Lowering not implemented yet. +! 
CHECK-LABEL: func @_QPteams_workdistribute_1 +subroutine teams_workdistribute_1() + use iso_fortran_env + real(kind=real32) :: a + real(kind=real32), dimension(10) :: x + real(kind=real32), dimension(10) :: y + !$omp teams + + !$omp distribute + do i = 1, 10 + x(i) = real(i, kind=real32) + end do + !$omp end distribute + + !$omp workdistribute + y = a * x + y + !$omp end workdistribute + !$omp end teams +end subroutine teams_workdistribute_1 diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir new file mode 100644 index 0000000000000..00d10d6264ec9 --- /dev/null +++ b/flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir @@ -0,0 +1,33 @@ +// RUN: fir-opt --lower-workdistribute %s | FileCheck %s + +// CHECK-LABEL: func.func @x({{.*}}) +// CHECK: omp.teams { +// CHECK: omp.parallel { +// CHECK: omp.distribute { +// CHECK: omp.wsloop { +// CHECK: omp.loop_nest (%[[VAL_1:.*]]) : index = (%[[ARG0:.*]]) to (%[[ARG1:.*]]) inclusive step (%[[ARG2:.*]]) { +// CHECK: %[[VAL_0:.*]] = arith.constant 0 : index +// CHECK: fir.store %[[VAL_0]] to %[[ARG4:.*]] : !fir.ref +// CHECK: omp.yield +// CHECK: } +// CHECK: } {omp.composite} +// CHECK: } {omp.composite} +// CHECK: omp.terminator +// CHECK: } {omp.composite} +// CHECK: omp.terminator +// CHECK: } +// CHECK: return +// CHECK: } +func.func @x(%lb : index, %ub : index, %step : index, %b : i1, %addr : !fir.ref) { + omp.teams { + omp.workdistribute { + fir.do_loop %iv = %lb to %ub step %step unordered { + %zero = arith.constant 0 : index + fir.store %zero to %addr : !fir.ref + } + omp.terminator + } + omp.terminator + } + return +} diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission-host.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-host.mlir new file mode 100644 index 0000000000000..04e60ca8bbf37 --- /dev/null +++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-host.mlir @@ -0,0 +1,117 @@ +// RUN: fir-opt --lower-workdistribute %s | FileCheck %s +// Test lowering of workdistribute after fission on host device. 
+
+// CHECK-LABEL: func.func @x(
+// CHECK: %[[VAL_0:.*]] = fir.alloca index {bindc_name = "lb"}
+// CHECK: fir.store %[[ARG0:.*]] to %[[VAL_0]] : !fir.ref<index>
+// CHECK: %[[VAL_1:.*]] = fir.alloca index {bindc_name = "ub"}
+// CHECK: fir.store %[[ARG1:.*]] to %[[VAL_1]] : !fir.ref<index>
+// CHECK: %[[VAL_2:.*]] = fir.alloca index {bindc_name = "step"}
+// CHECK: fir.store %[[ARG2:.*]] to %[[VAL_2]] : !fir.ref<index>
+// CHECK: %[[VAL_3:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "lb"}
+// CHECK: %[[VAL_4:.*]] = omp.map.info var_ptr(%[[VAL_1]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "ub"}
+// CHECK: %[[VAL_5:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "step"}
+// CHECK: %[[VAL_6:.*]] = omp.map.info var_ptr(%[[ARG3:.*]] : !fir.ref<index>, index) map_clauses(tofrom) capture(ByRef) -> !fir.ref<index> {name = "addr"}
+// CHECK: %[[VAL_7:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "lb"}
+// CHECK: %[[VAL_8:.*]] = omp.map.info var_ptr(%[[VAL_1]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "ub"}
+// CHECK: %[[VAL_9:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "step"}
+// CHECK: %[[VAL_10:.*]] = omp.map.info var_ptr(%[[ARG3]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "addr"}
+// CHECK: omp.target_data map_entries(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]], %[[VAL_6]] : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>) {
+// CHECK: %[[VAL_11:.*]] = fir.alloca index
+// CHECK: %[[VAL_12:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref<index>, index) map_clauses(from) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_from"}
+// CHECK: %[[VAL_13:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_to"}
+// CHECK: %[[VAL_14:.*]] = fir.alloca index
+// CHECK: %[[VAL_15:.*]] = omp.map.info var_ptr(%[[VAL_14]] : !fir.ref<index>, index) map_clauses(from) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_from"}
+// CHECK: %[[VAL_16:.*]] = omp.map.info var_ptr(%[[VAL_14]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_to"}
+// CHECK: %[[VAL_17:.*]] = fir.alloca index
+// CHECK: %[[VAL_18:.*]] = omp.map.info var_ptr(%[[VAL_17]] : !fir.ref<index>, index) map_clauses(from) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_from"}
+// CHECK: %[[VAL_19:.*]] = omp.map.info var_ptr(%[[VAL_17]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_to"}
+// CHECK: %[[VAL_20:.*]] = fir.alloca !fir.heap<index>
+// CHECK: %[[VAL_21:.*]] = omp.map.info var_ptr(%[[VAL_20]] : !fir.ref<!fir.heap<index>>, !fir.heap<index>) map_clauses(from) capture(ByRef) -> !fir.ref<!fir.heap<index>> {name = "__flang_workdistribute_from"}
+// CHECK: %[[VAL_22:.*]] = omp.map.info var_ptr(%[[VAL_20]] : !fir.ref<!fir.heap<index>>, !fir.heap<index>) map_clauses(to) capture(ByRef) -> !fir.ref<!fir.heap<index>> {name = "__flang_workdistribute_to"}
+// CHECK: %[[VAL_23:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_24:.*]] = fir.load %[[VAL_0]] : !fir.ref<index>
+// CHECK: %[[VAL_25:.*]] = fir.load %[[VAL_1]] : !fir.ref<index>
+// CHECK: %[[VAL_26:.*]] = fir.load %[[VAL_2]] : !fir.ref<index>
+// CHECK: %[[VAL_27:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_28:.*]] = arith.addi %[[VAL_25]], %[[VAL_25]] : index
+// CHECK: %[[VAL_29:.*]] = omp.target_allocmem %[[VAL_23]] : i32, index, %[[VAL_27]] {uniq_name = "dev_buf"}
+// CHECK: %[[VAL_30:.*]] = fir.convert %[[VAL_29]] : (i64) -> !fir.heap<index>
+// CHECK: fir.store %[[VAL_24]] to %[[VAL_11]] : !fir.ref<index>
+// CHECK: fir.store %[[VAL_25]] to %[[VAL_14]] : !fir.ref<index>
+// CHECK: fir.store %[[VAL_26]] to %[[VAL_17]] : !fir.ref<index>
+// CHECK: fir.store %[[VAL_30]] to %[[VAL_20]] : !fir.ref<!fir.heap<index>>
+// CHECK: omp.target host_eval(%[[VAL_24]] -> %[[VAL_31:.*]], %[[VAL_25]] -> %[[VAL_32:.*]], %[[VAL_26]] -> %[[VAL_33:.*]] : index, index, index) map_entries(%[[VAL_7]] -> %[[VAL_34:.*]], %[[VAL_8]] -> %[[VAL_35:.*]], %[[VAL_9]] -> %[[VAL_36:.*]], %[[VAL_10]] -> %[[VAL_37:.*]], %[[VAL_13]] -> %[[VAL_38:.*]], %[[VAL_16]] -> %[[VAL_39:.*]], %[[VAL_19]] -> %[[VAL_40:.*]], %[[VAL_22]] -> %[[VAL_41:.*]] : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<!fir.heap<index>>) {
+// CHECK: %[[VAL_42:.*]] = fir.load %[[VAL_38]] : !fir.ref<index>
+// CHECK: %[[VAL_43:.*]] = fir.load %[[VAL_39]] : !fir.ref<index>
+// CHECK: %[[VAL_44:.*]] = fir.load %[[VAL_40]] : !fir.ref<index>
+// CHECK: %[[VAL_45:.*]] = fir.load %[[VAL_41]] : !fir.ref<!fir.heap<index>>
+// CHECK: %[[VAL_46:.*]] = arith.addi %[[VAL_43]], %[[VAL_43]] : index
+// CHECK: omp.teams {
+// CHECK: omp.parallel {
+// CHECK: omp.distribute {
+// CHECK: omp.wsloop {
+// CHECK: omp.loop_nest (%[[VAL_47:.*]]) : index = (%[[VAL_31]]) to (%[[VAL_32]]) inclusive step (%[[VAL_33]]) {
+// CHECK: fir.store %[[VAL_46]] to %[[VAL_45]] : !fir.heap<index>
+// CHECK: omp.yield
+// CHECK: }
+// CHECK: } {omp.composite}
+// CHECK: } {omp.composite}
+// CHECK: omp.terminator
+// CHECK: } {omp.composite}
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: %[[VAL_48:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_49:.*]] = fir.load %[[VAL_11]] : !fir.ref<index>
+// CHECK: %[[VAL_50:.*]] = fir.load %[[VAL_14]] : !fir.ref<index>
+// CHECK: %[[VAL_51:.*]] = fir.load %[[VAL_17]] : !fir.ref<index>
+// CHECK: %[[VAL_52:.*]] = fir.load %[[VAL_20]] : !fir.ref<!fir.heap<index>>
+// CHECK: %[[VAL_53:.*]] = arith.addi %[[VAL_50]], %[[VAL_50]] : index
+// CHECK: fir.store %[[VAL_49]] to %[[VAL_52]] : !fir.heap<index>
+// CHECK: %[[VAL_54:.*]] = fir.convert %[[VAL_52]] : (!fir.heap<index>) -> i64
+// CHECK: omp.target_freemem %[[VAL_48]], %[[VAL_54]] : i32, i64
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: return
+// CHECK: }
+
+module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu", omp.is_gpu = false, omp.is_target_device = false} {
+func.func @x(%lb : index, %ub : index, %step : index, %addr : !fir.ref<index>) {
+  %lb_ref = fir.alloca index {bindc_name = "lb"}
+  fir.store %lb to %lb_ref : !fir.ref<index>
+  %ub_ref = fir.alloca index {bindc_name = "ub"}
+  fir.store %ub to %ub_ref : !fir.ref<index>
+  %step_ref = fir.alloca index {bindc_name = "step"}
+  fir.store %step to %step_ref : !fir.ref<index>
+
+  %lb_map = omp.map.info var_ptr(%lb_ref : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "lb"}
+  %ub_map = omp.map.info var_ptr(%ub_ref : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "ub"}
+  %step_map = omp.map.info var_ptr(%step_ref : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "step"}
+  %addr_map = omp.map.info var_ptr(%addr : !fir.ref<index>, index) map_clauses(tofrom) capture(ByRef) -> !fir.ref<index> {name = "addr"}
+
+  omp.target map_entries(%lb_map -> %ARG0, %ub_map -> %ARG1, %step_map -> %ARG2, %addr_map -> %ARG3 : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>) {
+    %lb_val = fir.load %ARG0 : !fir.ref<index>
+    %ub_val = fir.load %ARG1 : !fir.ref<index>
+    %step_val = fir.load %ARG2 : !fir.ref<index>
+    %one = arith.constant 1 : index
+
+    %20 = arith.addi %ub_val, %ub_val : index
+    omp.teams {
+      omp.workdistribute {
+        %dev_mem = fir.allocmem index, %one {uniq_name = "dev_buf"}
+        fir.do_loop %iv = %lb_val to %ub_val step %step_val unordered {
+          fir.store %20 to %dev_mem : !fir.heap<index>
+        }
+        fir.store %lb_val to %dev_mem : !fir.heap<index>
+        fir.freemem %dev_mem : !fir.heap<index>
+        omp.terminator
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+}
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir
new file mode 100644
index 0000000000000..062eb701b52ef
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir
@@ -0,0 +1,118 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+// Test lowering of workdistribute after fission on target device.
+
+// CHECK-LABEL: func.func @x(
+// CHECK: %[[VAL_0:.*]] = fir.alloca index {bindc_name = "lb"}
+// CHECK: fir.store %[[ARG0:.*]] to %[[VAL_0]] : !fir.ref<index>
+// CHECK: %[[VAL_1:.*]] = fir.alloca index {bindc_name = "ub"}
+// CHECK: fir.store %[[ARG1:.*]] to %[[VAL_1]] : !fir.ref<index>
+// CHECK: %[[VAL_2:.*]] = fir.alloca index {bindc_name = "step"}
+// CHECK: fir.store %[[ARG2:.*]] to %[[VAL_2]] : !fir.ref<index>
+// CHECK: %[[VAL_3:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "lb"}
+// CHECK: %[[VAL_4:.*]] = omp.map.info var_ptr(%[[VAL_1]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "ub"}
+// CHECK: %[[VAL_5:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "step"}
+// CHECK: %[[VAL_6:.*]] = omp.map.info var_ptr(%[[ARG3:.*]] : !fir.ref<index>, index) map_clauses(tofrom) capture(ByRef) -> !fir.ref<index> {name = "addr"}
+// CHECK: %[[VAL_7:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "lb"}
+// CHECK: %[[VAL_8:.*]] = omp.map.info var_ptr(%[[VAL_1]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "ub"}
+// CHECK: %[[VAL_9:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "step"}
+// CHECK: %[[VAL_10:.*]] = omp.map.info var_ptr(%[[ARG3]] : !fir.ref<index>, index) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<index> {name = "addr"}
+// CHECK: omp.target_data map_entries(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]], %[[VAL_6]] : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>) {
+// CHECK: %[[VAL_11:.*]] = fir.alloca index
+// CHECK: %[[VAL_12:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref<index>, index) map_clauses(from) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_from"}
+// CHECK: %[[VAL_13:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_to"}
+// CHECK: %[[VAL_14:.*]] = fir.alloca index
+// CHECK: %[[VAL_15:.*]] = omp.map.info var_ptr(%[[VAL_14]] : !fir.ref<index>, index) map_clauses(from) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_from"}
+// CHECK: %[[VAL_16:.*]] = omp.map.info var_ptr(%[[VAL_14]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_to"}
+// CHECK: %[[VAL_17:.*]] = fir.alloca index
+// CHECK: %[[VAL_18:.*]] = omp.map.info var_ptr(%[[VAL_17]] : !fir.ref<index>, index) map_clauses(from) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_from"}
+// CHECK: %[[VAL_19:.*]] = omp.map.info var_ptr(%[[VAL_17]] : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "__flang_workdistribute_to"}
+// CHECK: %[[VAL_20:.*]] = fir.alloca !fir.heap<index>
+// CHECK: %[[VAL_21:.*]] = omp.map.info var_ptr(%[[VAL_20]] : !fir.ref<!fir.heap<index>>, !fir.heap<index>) map_clauses(from) capture(ByRef) -> !fir.ref<!fir.heap<index>> {name = "__flang_workdistribute_from"}
+// CHECK: %[[VAL_22:.*]] = omp.map.info var_ptr(%[[VAL_20]] : !fir.ref<!fir.heap<index>>, !fir.heap<index>) map_clauses(to) capture(ByRef) -> !fir.ref<!fir.heap<index>> {name = "__flang_workdistribute_to"}
+// CHECK: %[[VAL_23:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_24:.*]] = fir.load %[[VAL_0]] : !fir.ref<index>
+// CHECK: %[[VAL_25:.*]] = fir.load %[[VAL_1]] : !fir.ref<index>
+// CHECK: %[[VAL_26:.*]] = fir.load %[[VAL_2]] : !fir.ref<index>
+// CHECK: %[[VAL_27:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_28:.*]] = arith.addi %[[VAL_25]], %[[VAL_25]] : index
+// CHECK: %[[VAL_29:.*]] = omp.target_allocmem %[[VAL_23]] : i32, index, %[[VAL_27]] {uniq_name = "dev_buf"}
+// CHECK: %[[VAL_30:.*]] = fir.convert %[[VAL_29]] : (i64) -> !fir.heap<index>
+// CHECK: fir.store %[[VAL_24]] to %[[VAL_11]] : !fir.ref<index>
+// CHECK: fir.store %[[VAL_25]] to %[[VAL_14]] : !fir.ref<index>
+// CHECK: fir.store %[[VAL_26]] to %[[VAL_17]] : !fir.ref<index>
+// CHECK: fir.store %[[VAL_30]] to %[[VAL_20]] : !fir.ref<!fir.heap<index>>
+// CHECK: omp.target map_entries(%[[VAL_7]] -> %[[VAL_31:.*]], %[[VAL_8]] -> %[[VAL_32:.*]], %[[VAL_9]] -> %[[VAL_33:.*]], %[[VAL_10]] -> %[[VAL_34:.*]], %[[VAL_13]] -> %[[VAL_35:.*]], %[[VAL_16]] -> %[[VAL_36:.*]], %[[VAL_19]] -> %[[VAL_37:.*]], %[[VAL_22]] -> %[[VAL_38:.*]] : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<!fir.heap<index>>) {
+// CHECK: %[[VAL_39:.*]] = fir.load %[[VAL_35]] : !fir.ref<index>
+// CHECK: %[[VAL_40:.*]] = fir.load %[[VAL_36]] : !fir.ref<index>
+// CHECK: %[[VAL_41:.*]] = fir.load %[[VAL_37]] : !fir.ref<index>
+// CHECK: %[[VAL_42:.*]] = fir.load %[[VAL_38]] : !fir.ref<!fir.heap<index>>
+// CHECK: %[[VAL_43:.*]] = arith.addi %[[VAL_40]], %[[VAL_40]] : index
+// CHECK: omp.teams {
+// CHECK: omp.parallel {
+// CHECK: omp.distribute {
+// CHECK: omp.wsloop {
+// CHECK: omp.loop_nest (%[[VAL_44:.*]]) : index = (%[[VAL_39]]) to (%[[VAL_40]]) inclusive step (%[[VAL_41]]) {
+// CHECK: fir.store %[[VAL_43]] to %[[VAL_42]] : !fir.heap<index>
+// CHECK: omp.yield
+// CHECK: }
+// CHECK: } {omp.composite}
+// CHECK: } {omp.composite}
+// CHECK: omp.terminator
+// CHECK: } {omp.composite}
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: %[[VAL_45:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_46:.*]] = fir.load %[[VAL_11]] : !fir.ref<index>
+// CHECK: %[[VAL_47:.*]] = fir.load %[[VAL_14]] : !fir.ref<index>
+// CHECK: %[[VAL_48:.*]] = fir.load %[[VAL_17]] : !fir.ref<index>
+// CHECK: %[[VAL_49:.*]] = fir.load %[[VAL_20]] : !fir.ref<!fir.heap<index>>
+// CHECK: %[[VAL_50:.*]] = arith.addi %[[VAL_47]], %[[VAL_47]] : index
+// CHECK: fir.store %[[VAL_46]] to %[[VAL_49]] : !fir.heap<index>
+// CHECK: %[[VAL_51:.*]] = fir.convert %[[VAL_49]] : (!fir.heap<index>) -> i64
+// CHECK: omp.target_freemem %[[VAL_45]], %[[VAL_51]] : i32, i64
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: return
+// CHECK: }
+
+
+module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} {
+func.func @x(%lb : index, %ub : index, %step : index, %addr : !fir.ref<index>) {
+  %lb_ref = fir.alloca index {bindc_name = "lb"}
+  fir.store %lb to %lb_ref : !fir.ref<index>
+  %ub_ref = fir.alloca index {bindc_name = "ub"}
+  fir.store %ub to %ub_ref : !fir.ref<index>
+  %step_ref = fir.alloca index {bindc_name = "step"}
+  fir.store %step to %step_ref : !fir.ref<index>
+
+  %lb_map = omp.map.info var_ptr(%lb_ref : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "lb"}
+  %ub_map = omp.map.info var_ptr(%ub_ref : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "ub"}
+  %step_map = omp.map.info var_ptr(%step_ref : !fir.ref<index>, index) map_clauses(to) capture(ByRef) -> !fir.ref<index> {name = "step"}
+  %addr_map = omp.map.info var_ptr(%addr : !fir.ref<index>, index) map_clauses(tofrom) capture(ByRef) -> !fir.ref<index> {name = "addr"}
+
+  omp.target map_entries(%lb_map -> %ARG0, %ub_map -> %ARG1, %step_map -> %ARG2, %addr_map -> %ARG3 : !fir.ref<index>, !fir.ref<index>, !fir.ref<index>, !fir.ref<index>) {
+    %lb_val = fir.load %ARG0 : !fir.ref<index>
+    %ub_val = fir.load %ARG1 : !fir.ref<index>
+    %step_val = fir.load %ARG2 : !fir.ref<index>
+    %one = arith.constant 1 : index
+
+    %20 = arith.addi %ub_val, %ub_val : index
+    omp.teams {
+      omp.workdistribute {
+        %dev_mem = fir.allocmem index, %one {uniq_name = "dev_buf"}
+        fir.do_loop %iv = %lb_val to %ub_val step %step_val unordered {
+          fir.store %20 to %dev_mem : !fir.heap<index>
+        }
+        fir.store %lb_val to %dev_mem : !fir.heap<index>
+        fir.freemem %dev_mem : !fir.heap<index>
+        omp.terminator
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+}
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir
new file mode 100644
index 0000000000000..c562b7009664d
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir
@@ -0,0 +1,71 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+
+// CHECK-LABEL: func.func @test_fission_workdistribute(
+// CHECK: %[[VAL_0:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 9 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 5.000000e+00 : f32
+// CHECK: fir.store %[[VAL_3]] to %[[ARG2:.*]] : !fir.ref<f32>
+// CHECK: omp.teams {
+// CHECK: omp.parallel {
+// CHECK: omp.distribute {
+// CHECK: omp.wsloop {
+// CHECK: omp.loop_nest (%[[VAL_4:.*]]) : index = (%[[VAL_0]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) {
+// CHECK: %[[VAL_5:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref<f32>
+// CHECK: %[[VAL_7:.*]] = fir.coordinate_of %[[ARG1:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK: fir.store %[[VAL_6]] to %[[VAL_7]] : !fir.ref<f32>
+// CHECK: omp.yield
+// CHECK: }
+// CHECK: } {omp.composite}
+// CHECK: } {omp.composite}
+// CHECK: omp.terminator
+// CHECK: } {omp.composite}
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: fir.call @regular_side_effect_func(%[[ARG2:.*]]) : (!fir.ref<f32>) -> ()
+// CHECK: fir.call @my_fir_parallel_runtime_func(%[[ARG3:.*]]) : (!fir.ref<f32>) -> ()
+// CHECK: fir.do_loop %[[VAL_8:.*]] = %[[VAL_0]] to %[[VAL_2]] step %[[VAL_1]] {
+// CHECK: %[[VAL_9:.*]] = fir.coordinate_of %[[ARG0]], %[[VAL_8]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK: fir.store %[[VAL_3]] to %[[VAL_9]] : !fir.ref<f32>
+// CHECK: }
+// CHECK: %[[VAL_10:.*]] = fir.load %[[ARG2:.*]] : !fir.ref<f32>
+// CHECK: fir.store %[[VAL_10]] to %[[ARG3:.*]] : !fir.ref<f32>
+// CHECK: return
+// CHECK: }
+module {
+func.func @regular_side_effect_func(%arg0: !fir.ref<f32>) {
+  return
+}
+func.func @my_fir_parallel_runtime_func(%arg0: !fir.ref<f32>) attributes {fir.runtime} {
+  return
+}
+func.func @test_fission_workdistribute(%arr1: !fir.ref<!fir.array<10xf32>>, %arr2: !fir.ref<!fir.array<10xf32>>, %scalar_ref1: !fir.ref<f32>, %scalar_ref2: !fir.ref<f32>) {
+  %c0_idx = arith.constant 0 : index
+  %c1_idx = arith.constant 1 : index
+  %c9_idx = arith.constant 9 : index
+  %float_val = arith.constant 5.0 : f32
+  omp.teams {
+    omp.workdistribute {
+      fir.store %float_val to %scalar_ref1 : !fir.ref<f32>
+      fir.do_loop %iv = %c0_idx to %c9_idx step %c1_idx unordered {
+        %elem_ptr_arr1 = fir.coordinate_of %arr1, %iv : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+        %loaded_val_loop1 = fir.load %elem_ptr_arr1 : !fir.ref<f32>
+        %elem_ptr_arr2 = fir.coordinate_of %arr2, %iv : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+        fir.store %loaded_val_loop1 to %elem_ptr_arr2 : !fir.ref<f32>
+      }
+      fir.call @regular_side_effect_func(%scalar_ref1) : (!fir.ref<f32>) -> ()
+      fir.call @my_fir_parallel_runtime_func(%scalar_ref2) : (!fir.ref<f32>) -> ()
+      fir.do_loop %jv = %c0_idx to %c9_idx step %c1_idx {
+        %elem_ptr_ordered_loop = fir.coordinate_of %arr1, %jv : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+        fir.store %float_val to %elem_ptr_ordered_loop : !fir.ref<f32>
+      }
+      %loaded_for_hoist = fir.load %scalar_ref1 : !fir.ref<f32>
+      fir.store %loaded_for_hoist to %scalar_ref2 : !fir.ref<f32>
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+}
diff --git a/flang/test/Transforms/OpenMP/lower-workdistribute-runtime-assign-scalar.mlir b/flang/test/Transforms/OpenMP/lower-workdistribute-runtime-assign-scalar.mlir
new file mode 100644
index 0000000000000..03d5d71df0a82
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workdistribute-runtime-assign-scalar.mlir
@@ -0,0 +1,108 @@
+// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
+
+// Test lowering of workdistribute for a scalar assignment within a target teams workdistribute region.
+// The test checks that the scalar assignment is correctly lowered to wsloop and loop_nest operations.
+
+// Example Fortran code:
+// !$omp target teams workdistribute
+// y = 3.0_real32
+// !$omp end target teams workdistribute
+
+
+// CHECK-LABEL: func.func @x(
+// CHECK: omp.target {{.*}} {
+// CHECK: omp.teams {
+// CHECK: omp.parallel {
+// CHECK: omp.distribute {
+// CHECK: omp.wsloop {
+// CHECK: omp.loop_nest (%[[VAL_73:.*]]) : index = (%[[VAL_66:.*]]) to (%[[VAL_72:.*]]) inclusive step (%[[VAL_67:.*]]) {
+// CHECK: %[[VAL_74:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_75:.*]]:3 = fir.box_dims %[[VAL_64:.*]], %[[VAL_74]] : (!fir.box<!fir.array<?x?xf32>>, index) -> (index, index, index)
+// CHECK: %[[VAL_76:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_77:.*]]:3 = fir.box_dims %[[VAL_64]], %[[VAL_76]] : (!fir.box<!fir.array<?x?xf32>>, index) -> (index, index, index)
+// CHECK: %[[VAL_78:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_79:.*]] = arith.remsi %[[VAL_73]], %[[VAL_77]]#1 : index
+// CHECK: %[[VAL_80:.*]] = arith.addi %[[VAL_79]], %[[VAL_78]] : index
+// CHECK: %[[VAL_81:.*]] = arith.divsi %[[VAL_73]], %[[VAL_77]]#1 : index
+// CHECK: %[[VAL_82:.*]] = arith.remsi %[[VAL_81]], %[[VAL_75]]#1 : index
+// CHECK: %[[VAL_83:.*]] = arith.addi %[[VAL_82]], %[[VAL_78]] : index
+// CHECK: %[[VAL_84:.*]] = fir.array_coor %[[VAL_64]] %[[VAL_83]], %[[VAL_80]] : (!fir.box<!fir.array<?x?xf32>>, index, index) -> !fir.ref<f32>
+// CHECK: fir.store %[[VAL_65:.*]] to %[[VAL_84]] : !fir.ref<f32>
+// CHECK: omp.yield
+// CHECK: }
+// CHECK: } {omp.composite}
+// CHECK: } {omp.composite}
+// CHECK: omp.terminator
+// CHECK: } {omp.composite}
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: return
+// CHECK: }
+// CHECK: func.func private @_FortranAAssign(!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.ref<i8>, i32) attributes {fir.runtime}
+
+module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} {
+func.func @x(%arr : !fir.ref<!fir.array<?x?xf32>>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c78 = arith.constant 78 : index
+  %cst = arith.constant 3.000000e+00 : f32
+  %0 = fir.alloca i32
+  %1 = fir.alloca i32
+  %c10 = arith.constant 10 : index
+  %c20 = arith.constant 20 : index
+  %194 = arith.subi %c10, %c1 : index
+  %195 = omp.map.bounds lower_bound(%c0 : index) upper_bound(%194 : index) extent(%c10 : index) stride(%c1 : index) start_idx(%c1 : index)
+  %196 = arith.subi %c20, %c1 : index
+  %197 = omp.map.bounds lower_bound(%c0 : index) upper_bound(%196 : index) extent(%c20 : index) stride(%c1 : index) start_idx(%c1 : index)
+  %198 = omp.map.info var_ptr(%arr : !fir.ref<!fir.array<?x?xf32>>, f32) map_clauses(implicit, tofrom) capture(ByRef) bounds(%195, %197) -> !fir.ref<!fir.array<?x?xf32>> {name = "y"}
+  %199 = omp.map.info var_ptr(%1 : !fir.ref<i32>, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !fir.ref<i32> {name = ""}
+  %200 = omp.map.info var_ptr(%0 : !fir.ref<i32>, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !fir.ref<i32> {name = ""}
+  omp.target map_entries(%198 -> %arg5, %199 -> %arg6, %200 -> %arg7 : !fir.ref<!fir.array<?x?xf32>>, !fir.ref<i32>, !fir.ref<i32>) {
+    %c0_0 = arith.constant 0 : index
+    %201 = fir.load %arg7 : !fir.ref<i32>
+    %202 = fir.load %arg6 : !fir.ref<i32>
+    %203 = fir.convert %202 : (i32) -> i64
+    %204 = fir.convert %201 : (i32) -> i64
+    %205 = fir.convert %204 : (i64) -> index
+    %206 = arith.cmpi sgt, %205, %c0_0 : index
+    %207 = fir.convert %203 : (i64) -> index
+    %208 = arith.cmpi sgt, %207, %c0_0 : index
+    %209 = arith.select %208, %207, %c0_0 : index
+    %210 = arith.select %206, %205, %c0_0 : index
+    %211 = fir.shape %210, %209 : (index, index) -> !fir.shape<2>
+    %212 = fir.declare %arg5(%211) {uniq_name = "_QFFaxpy_array_workdistributeEy"} : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>) -> !fir.ref<!fir.array<?x?xf32>>
+    %213 = fir.embox %212(%211) : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>) -> !fir.box<!fir.array<?x?xf32>>
+    omp.teams {
+      %214 = fir.alloca !fir.box<!fir.array<?x?xf32>> {pinned}
+      omp.workdistribute {
+        %215 = fir.alloca f32
+        %216 = fir.embox %215 : (!fir.ref<f32>) -> !fir.box<f32>
+        %217 = fir.shape %210, %209 : (index, index) -> !fir.shape<2>
+        %218 = fir.embox %212(%217) : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>) -> !fir.box<!fir.array<?x?xf32>>
+        fir.store %218 to %214 : !fir.ref<!fir.box<!fir.array<?x?xf32>>>
+        %219 = fir.address_of(@_QQclXf9c642d28e5bba1f07fa9a090b72f4fc) : !fir.ref<!fir.char<1,78>>
+        %c39_i32 = arith.constant 39 : i32
+        %220 = fir.convert %214 : (!fir.ref<!fir.box<!fir.array<?x?xf32>>>) -> !fir.ref<!fir.box<none>>
+        %221 = fir.convert %216 : (!fir.box<f32>) -> !fir.box<none>
+        %222 = fir.convert %219 : (!fir.ref<!fir.char<1,78>>) -> !fir.ref<i8>
+        fir.call @_FortranAAssign(%220, %221, %222, %c39_i32) : (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.ref<i8>, i32) -> ()
+        omp.terminator
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
+func.func private @_FortranAAssign(!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.ref<i8>, i32) attributes {fir.runtime}
+
+fir.global linkonce @_QQclXf9c642d28e5bba1f07fa9a090b72f4fc constant : !fir.char<1,78> {
+  %0 = fir.string_lit "File: /work/github/skc7/llvm-project/build_fomp_reldebinfo/saxpy_tests/\00"(78) : !fir.char<1,78>
+  fir.has_value %0 : !fir.char<1,78>
+}
+}