Skip to content

Commit ae29a16

Browse files
authored
[flang][FIR][Mem2Reg] Add supoort for FIR. (#172808)
This patch implements Mem2Reg interfaces for FIR.
1 parent aa85989 commit ae29a16

File tree

4 files changed

+180
-5
lines changed

4 files changed

+180
-5
lines changed

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
include "mlir/Dialect/Arith/IR/ArithBase.td"
1818
include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
1919
include "mlir/Dialect/LLVMIR/LLVMAttrDefs.td"
20+
include "mlir/Interfaces/MemorySlotInterfaces.td"
2021
include "mlir/Interfaces/ViewLikeInterface.td"
2122
include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.td"
2223
include "flang/Optimizer/Dialect/FIRDialect.td"
@@ -80,7 +81,10 @@ def AnyRefOfConstantSizeAggregateType : TypeConstraint<
8081
// Memory SSA operations
8182
//===----------------------------------------------------------------------===//
8283

83-
def fir_AllocaOp : fir_Op<"alloca", [AttrSizedOperandSegments]> {
84+
def fir_AllocaOp : fir_Op<"alloca", [
85+
AttrSizedOperandSegments,
86+
DeclareOpInterfaceMethods<PromotableAllocationOpInterface>
87+
]> {
8488
let summary = "allocate storage for a temporary on the stack given a type";
8589
let description = [{
8690
This primitive operation is used to allocate an object on the stack. A
@@ -288,8 +292,11 @@ def fir_FreeMemOp : fir_Op<"freemem", [MemoryEffects<[MemFree]>]> {
288292
let assemblyFormat = "$heapref attr-dict `:` qualified(type($heapref))";
289293
}
290294

291-
def fir_LoadOp : fir_OneResultOp<"load", [FirAliasTagOpInterface,
292-
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
295+
def fir_LoadOp : fir_OneResultOp<"load", [
296+
FirAliasTagOpInterface,
297+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
298+
DeclareOpInterfaceMethods<PromotableMemOpInterface>
299+
]> {
293300
let summary = "load a value from a memory reference";
294301
let description = [{
295302
Load a value from a memory reference into an ssa-value (virtual register).
@@ -319,8 +326,11 @@ def fir_LoadOp : fir_OneResultOp<"load", [FirAliasTagOpInterface,
319326
}];
320327
}
321328

322-
def fir_StoreOp : fir_Op<"store", [FirAliasTagOpInterface,
323-
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
329+
def fir_StoreOp : fir_Op<"store", [
330+
FirAliasTagOpInterface,
331+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
332+
DeclareOpInterfaceMethods<PromotableMemOpInterface>
333+
]> {
324334
let summary = "store an SSA-value to a memory location";
325335

326336
let description = [{

flang/include/flang/Optimizer/Support/InitFIR.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ inline void registerMLIRPassesForFortranTools() {
129129
mlir::affine::registerAffineLoopTilingPass();
130130
mlir::affine::registerAffineDataCopyGenerationPass();
131131

132+
mlir::registerMem2RegPass();
132133
mlir::registerLowerAffinePass();
133134
}
134135

flang/lib/Optimizer/Dialect/FIROps.cpp

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,36 @@ static mlir::Type wrapAllocaResultType(mlir::Type intype) {
186186
return fir::ReferenceType::get(intype);
187187
}
188188

189+
llvm::SmallVector<mlir::MemorySlot> fir::AllocaOp::getPromotableSlots() {
190+
// TODO: support promotion of dynamic allocas
191+
if (isDynamic())
192+
return {};
193+
194+
return {mlir::MemorySlot{getResult(), getAllocatedType()}};
195+
}
196+
197+
mlir::Value fir::AllocaOp::getDefaultValue(const mlir::MemorySlot &slot,
198+
mlir::OpBuilder &builder) {
199+
return fir::UndefOp::create(builder, getLoc(), slot.elemType);
200+
}
201+
202+
void fir::AllocaOp::handleBlockArgument(const mlir::MemorySlot &slot,
203+
mlir::BlockArgument argument,
204+
mlir::OpBuilder &builder) {}
205+
206+
std::optional<mlir::PromotableAllocationOpInterface>
207+
fir::AllocaOp::handlePromotionComplete(const mlir::MemorySlot &slot,
208+
mlir::Value defaultValue,
209+
mlir::OpBuilder &builder) {
210+
if (defaultValue && defaultValue.use_empty()) {
211+
assert(mlir::isa<fir::UndefOp>(defaultValue.getDefiningOp()) &&
212+
"Expected undef op to be the default value");
213+
defaultValue.getDefiningOp()->erase();
214+
}
215+
this->erase();
216+
return std::nullopt;
217+
}
218+
189219
mlir::Type fir::AllocaOp::getAllocatedType() {
190220
return mlir::cast<fir::ReferenceType>(getType()).getEleTy();
191221
}
@@ -2861,6 +2891,39 @@ llvm::SmallVector<mlir::Attribute> fir::LenParamIndexOp::getAttributes() {
28612891
// LoadOp
28622892
//===----------------------------------------------------------------------===//
28632893

2894+
bool fir::LoadOp::loadsFrom(const mlir::MemorySlot &slot) {
2895+
return getMemref() == slot.ptr;
2896+
}
2897+
2898+
bool fir::LoadOp::storesTo(const mlir::MemorySlot &slot) { return false; }
2899+
2900+
mlir::Value fir::LoadOp::getStored(const mlir::MemorySlot &slot,
2901+
mlir::OpBuilder &builder,
2902+
mlir::Value reachingDef,
2903+
const mlir::DataLayout &dataLayout) {
2904+
return mlir::Value();
2905+
}
2906+
2907+
bool fir::LoadOp::canUsesBeRemoved(
2908+
const mlir::MemorySlot &slot,
2909+
const SmallPtrSetImpl<mlir::OpOperand *> &blockingUses,
2910+
mlir::SmallVectorImpl<mlir::OpOperand *> &newBlockingUses,
2911+
const mlir::DataLayout &dataLayout) {
2912+
if (blockingUses.size() != 1)
2913+
return false;
2914+
mlir::Value blockingUse = (*blockingUses.begin())->get();
2915+
return blockingUse == slot.ptr && getMemref() == slot.ptr;
2916+
}
2917+
2918+
mlir::DeletionKind fir::LoadOp::removeBlockingUses(
2919+
const mlir::MemorySlot &slot,
2920+
const SmallPtrSetImpl<mlir::OpOperand *> &blockingUses,
2921+
mlir::OpBuilder &builder, mlir::Value reachingDefinition,
2922+
const mlir::DataLayout &dataLayout) {
2923+
getResult().replaceAllUsesWith(reachingDefinition);
2924+
return mlir::DeletionKind::Delete;
2925+
}
2926+
28642927
void fir::LoadOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
28652928
mlir::Value refVal) {
28662929
if (!refVal) {
@@ -4256,6 +4319,39 @@ llvm::LogicalResult fir::SliceOp::verify() {
42564319
// StoreOp
42574320
//===----------------------------------------------------------------------===//
42584321

4322+
bool fir::StoreOp::loadsFrom(const mlir::MemorySlot &slot) { return false; }
4323+
4324+
bool fir::StoreOp::storesTo(const mlir::MemorySlot &slot) {
4325+
return getMemref() == slot.ptr;
4326+
}
4327+
4328+
mlir::Value fir::StoreOp::getStored(const mlir::MemorySlot &slot,
4329+
mlir::OpBuilder &builder,
4330+
mlir::Value reachingDef,
4331+
const mlir::DataLayout &dataLayout) {
4332+
return getValue();
4333+
}
4334+
4335+
bool fir::StoreOp::canUsesBeRemoved(
4336+
const mlir::MemorySlot &slot,
4337+
const SmallPtrSetImpl<mlir::OpOperand *> &blockingUses,
4338+
mlir::SmallVectorImpl<mlir::OpOperand *> &newBlockingUses,
4339+
const mlir::DataLayout &dataLayout) {
4340+
if (blockingUses.size() != 1)
4341+
return false;
4342+
mlir::Value blockingUse = (*blockingUses.begin())->get();
4343+
return blockingUse == slot.ptr && getMemref() == slot.ptr &&
4344+
getValue() != slot.ptr;
4345+
}
4346+
4347+
mlir::DeletionKind fir::StoreOp::removeBlockingUses(
4348+
const mlir::MemorySlot &slot,
4349+
const SmallPtrSetImpl<mlir::OpOperand *> &blockingUses,
4350+
mlir::OpBuilder &builder, mlir::Value reachingDefinition,
4351+
const mlir::DataLayout &dataLayout) {
4352+
return mlir::DeletionKind::Delete;
4353+
}
4354+
42594355
mlir::Type fir::StoreOp::elementType(mlir::Type refType) {
42604356
return fir::dyn_cast_ptrEleTy(refType);
42614357
}

flang/test/Fir/mem2reg.mlir

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// RUN: fir-opt %s --allow-unregistered-dialect --mem2reg --split-input-file | FileCheck %s
2+
3+
// CHECK-LABEL: func.func @basic() -> i32 {
4+
// CHECK: %[[CONSTANT_0:.*]] = arith.constant 5 : i32
5+
// CHECK: return %[[CONSTANT_0]] : i32
6+
// CHECK: }
7+
func.func @basic() -> i32 {
8+
%0 = arith.constant 5 : i32
9+
%1 = fir.alloca i32
10+
fir.store %0 to %1 : !fir.ref<i32>
11+
%2 = fir.load %1 : !fir.ref<i32>
12+
return %2 : i32
13+
}
14+
15+
// -----
16+
17+
// CHECK-LABEL: func.func @default_value() -> i32 {
18+
// CHECK: %[[UNDEFINED_0:.*]] = fir.undefined i32
19+
// CHECK: %[[CONSTANT_0:.*]] = arith.constant 5 : i32
20+
// CHECK: return %[[UNDEFINED_0]] : i32
21+
// CHECK: }
22+
func.func @default_value() -> i32 {
23+
%0 = arith.constant 5 : i32
24+
%1 = fir.alloca i32
25+
%2 = fir.load %1 : !fir.ref<i32>
26+
fir.store %0 to %1 : !fir.ref<i32>
27+
return %2 : i32
28+
}
29+
30+
// -----
31+
32+
// CHECK-LABEL: func.func @basic_float() -> f32 {
33+
// CHECK: %[[CONSTANT_0:.*]] = arith.constant 5.200000e+00 : f32
34+
// CHECK: return %[[CONSTANT_0]] : f32
35+
// CHECK: }
36+
func.func @basic_float() -> f32 {
37+
%0 = arith.constant 5.2 : f32
38+
%1 = fir.alloca f32
39+
fir.store %0 to %1 : !fir.ref<f32>
40+
%2 = fir.load %1 : !fir.ref<f32>
41+
return %2 : f32
42+
}
43+
44+
// -----
45+
46+
// CHECK-LABEL: func.func @cycle(
47+
// CHECK-SAME: %[[ARG0:.*]]: i64,
48+
// CHECK-SAME: %[[ARG1:.*]]: i1,
49+
// CHECK-SAME: %[[ARG2:.*]]: i64) {
50+
// CHECK: cf.cond_br %[[ARG1]], ^bb1(%[[ARG2]] : i64), ^bb2(%[[ARG2]] : i64)
51+
// CHECK: ^bb1(%[[VAL_0:.*]]: i64):
52+
// CHECK: "test.use"(%[[VAL_0]]) : (i64) -> ()
53+
// CHECK: cf.br ^bb2(%[[ARG0]] : i64)
54+
// CHECK: ^bb2(%[[VAL_1:.*]]: i64):
55+
// CHECK: cf.br ^bb1(%[[VAL_1]] : i64)
56+
// CHECK: }
57+
func.func @cycle(%arg0: i64, %arg1: i1, %arg2: i64) {
58+
%alloca = fir.alloca i64
59+
fir.store %arg2 to %alloca : !fir.ref<i64>
60+
cf.cond_br %arg1, ^bb1, ^bb2
61+
^bb1:
62+
%use = fir.load %alloca : !fir.ref<i64>
63+
"test.use"(%use) : (i64) -> ()
64+
fir.store %arg0 to %alloca : !fir.ref<i64>
65+
cf.br ^bb2
66+
^bb2:
67+
cf.br ^bb1
68+
}

0 commit comments

Comments
 (0)