Skip to content

Commit 1ff6e7f

Browse files
committed
remove template and format the code
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
1 parent dd960bc commit 1ff6e7f

File tree

3 files changed

+388
-362
lines changed

3 files changed

+388
-362
lines changed

compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp

+317
Original file line numberDiff line numberDiff line change
@@ -920,4 +920,321 @@ bool isArgmaxOp(linalg::GenericOp genericOp) {
920920
return true;
921921
}
922922

923+
// The three ops forming the combiner of an argmax-style linalg.generic
// reduction body (see collectArgmaxCombinerOps for the matched pattern).
// All pointers refer to ops inside the generic's region.
struct ArgmaxCombinerOps {
  Operation *maxOp = nullptr;    // arith.maximumf: running maximum value.
  Operation *selectOp = nullptr; // arith.select: picks the index of the max.
  Operation *cmpOp = nullptr;    // arith.cmpf: compares input vs running max.
};
928+
929+
// Matches the combiner pattern in a linalg.generic argmax-style reduction:
930+
// Example MLIR:
931+
// %4:2 = linalg.generic {
932+
// indexing_maps = [...],
933+
// iterator_types = ["parallel", "reduction"]
934+
// } ins(%arg0 : tensor<?x128xbf16>) outs(%1, %3 : tensor<?xbf16>,
935+
// tensor<?xi64>) {
936+
// ^bb0(%in: bf16, %out: bf16, %out_0: i64):
937+
// %5 = linalg.index 1 : index
938+
// %6 = arith.index_cast %5 : index to i64
939+
// %7 = arith.maximumf %in, %out : bf16
940+
// %8 = arith.cmpf ogt, %in, %out : bf16
941+
// %9 = arith.select %8, %6, %out_0 : i64
942+
// linalg.yield %7, %9 : bf16, i64
943+
// } -> (tensor<?xbf16>, tensor<?xi64>)
944+
//
945+
// This function extracts the `arith.maximumf`, `arith.cmpf`, and `arith.select`
946+
// operations from the body to facilitate transformations such as split
947+
// reduction.
948+
static FailureOr<ArgmaxCombinerOps>
949+
collectArgmaxCombinerOps(linalg::GenericOp genericOp) {
950+
// if (combinerOps.size() < 3) {
951+
// return genericOp->emitError(
952+
// "combinerOps must have space for exactly 3 elements");
953+
// }
954+
955+
assert(IREE::LinalgExt::isArgmaxOp(genericOp) &&
956+
"expected operation to be an argmax op");
957+
958+
ArgmaxCombinerOps ops;
959+
960+
auto yieldOp = cast<linalg::YieldOp>(genericOp.getBody()->getTerminator());
961+
962+
// Extract max value producer: arith.maximumf.
963+
Value maxResult = yieldOp.getOperand(0);
964+
auto maxOp = dyn_cast<arith::MaximumFOp>(maxResult.getDefiningOp());
965+
966+
// Extract index result producer: arith.select.
967+
Value indexResult = yieldOp.getOperand(1);
968+
auto selectOp = dyn_cast<arith::SelectOp>(indexResult.getDefiningOp());
969+
970+
// Extract the condition of the select, expected to be arith.cmpf with
971+
// predicate OGT.
972+
auto cmpOp = dyn_cast<arith::CmpFOp>(selectOp.getCondition().getDefiningOp());
973+
974+
ops.maxOp = maxOp;
975+
ops.selectOp = selectOp;
976+
ops.cmpOp = cmpOp;
977+
978+
return ops;
979+
}
980+
981+
/// Splits a linalg.generic argmax reduction into a two-stage reduction.
///
/// Stage 1 (partial): the reduction dimension of size K is reshaped into
/// [K/ratio, ratio] (or [ratio, K/ratio] when `control.innerParallel`), a new
/// parallel dimension is inserted, and a partial argmax is computed per tile.
/// The cloned body is patched so the selected index is the *global* index
/// (outer * ratio + inner) rather than the local tile index.
///
/// Stage 2 (final): a second linalg.generic reduces the partial
/// (value, index) pairs over the inserted dimension using the same
/// maximumf/cmpf/select combiner.
///
/// Returns the partial and final ops in a SplitReductionResult, or a match
/// failure if the op cannot be split (ratio <= 1, dynamic or non-divisible
/// reduction extent, out-of-range insertion index, or unrecognized combiner).
FailureOr<linalg::SplitReductionResult>
splitArgmaxReduction(RewriterBase &rewriter, linalg::GenericOp genericOp,
                     linalg::ControlSplitReductionFn controlSplitReductionFn) {
  assert(IREE::LinalgExt::isArgmaxOp(genericOp) &&
         "expected operation to be an argmax op");

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(genericOp);
  Location loc = genericOp->getLoc();

  linalg::SplitReductionOptions control = controlSplitReductionFn(genericOp);
  int64_t ratio = control.ratio;
  unsigned insertSplitIndex = control.index;
  unsigned insertSplitDimension = control.index;
  if (ratio <= 1) {
    return rewriter.notifyMatchFailure(
        genericOp, "split ratio needs to be greater than 1");
  }

  SmallVector<unsigned> dims;
  genericOp.getReductionDims(dims);

  // isArgmaxOp guarantees exactly one reduction dimension.
  unsigned reductionDim = dims[0];
  if (control.innerParallel) {
    insertSplitDimension = reductionDim + 1;
  }

  SmallVector<int64_t, 4> loopRanges = genericOp.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDim];

  // Guard against a dynamic or non-divisible extent before doing any
  // arithmetic with it (kDynamic is a sentinel, not a real size).
  if (reductionDimSize == ShapedType::kDynamic ||
      reductionDimSize % ratio != 0) {
    return rewriter.notifyMatchFailure(
        genericOp, "Reduction dimension not divisible by split ratio");
  }

  // The total number of output elements along the new parallel dimension.
  int64_t outputDimsize = reductionDimSize / ratio;

  if (insertSplitIndex >
      genericOp.getShape(genericOp.getDpsInitOperand(0)).size()) {
    return rewriter.notifyMatchFailure(genericOp,
                                       "Insert dimension position too large "
                                       "compared to intermediate tensor size");
  }

  FailureOr<ArgmaxCombinerOps> maybeOps = collectArgmaxCombinerOps(genericOp);
  if (failed(maybeOps))
    return rewriter.notifyMatchFailure(genericOp,
                                       "invalid combiner for argmax");

  ArgmaxCombinerOps combinerOps = *maybeOps;
  Operation *reductionOp = combinerOps.maxOp;

  std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp);
  if (!identity.has_value())
    return rewriter.notifyMatchFailure(
        genericOp, "Unknown identity value for the reduction");

  SmallVector<Value> newInputs;
  SmallVector<AffineMap> newMaps;
  // Calculate the new shapes and indexing maps of the input operands.
  for (OpOperand *operand : genericOp.getDpsInputOperands()) {
    AffineMap map = genericOp.getMatchingIndexingMap(operand);
    SmallVector<int64_t> newShape;
    SmallVector<AffineExpr> exprs;
    SmallVector<ReassociationIndices> reassociation;
    unsigned index = 0;
    for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
      unsigned dim = map.getDimPosition(idx);
      if (reductionDim == dim) {
        if (control.innerParallel) {
          newShape.push_back(ratio); // reduce
          newShape.push_back(genericOp.getShape(operand)[idx] /
                             ratio); // parallel (insert)
          exprs.push_back(rewriter.getAffineDimExpr(
              dim < insertSplitDimension ? dim : dim + 1));
          exprs.push_back(rewriter.getAffineDimExpr(insertSplitDimension));
        } else {
          newShape.push_back(genericOp.getShape(operand)[idx] /
                             ratio); // parallel (insert)
          newShape.push_back(ratio); // reduce
          exprs.push_back(rewriter.getAffineDimExpr(insertSplitDimension));
          exprs.push_back(rewriter.getAffineDimExpr(
              dim < insertSplitDimension ? dim : dim + 1));
        }
        // Braced-init-list evaluation is left-to-right, so this records
        // {index, index + 1} for the two split dims.
        reassociation.push_back({index++, index++});
        continue;
      }
      newShape.push_back(genericOp.getShape(operand)[idx]);
      exprs.push_back(rewriter.getAffineDimExpr(
          dim < insertSplitDimension ? dim : dim + 1));
      reassociation.push_back({index++});
    }
    newMaps.push_back(
        AffineMap::get(map.getNumDims() + 1, 0, exprs, genericOp.getContext()));
    // If the shape is unchanged the input doesn't change.
    if (newShape == genericOp.getShape(operand)) {
      newInputs.push_back(operand->get());
      continue;
    }
    Type newType = RankedTensorType::get(
        newShape,
        cast<RankedTensorType>(operand->get().getType()).getElementType());

    Value newInput = rewriter.create<tensor::ExpandShapeOp>(
        loc, newType, operand->get(), reassociation);
    newInputs.push_back(newInput);
  }

  // Compute the shapes and indexing maps of both outputs (value, index) with
  // the new parallel dimension inserted at `insertSplitIndex`.
  SmallVector<SmallVector<int64_t>> newOutputShapes;
  for (int i = 0; i < genericOp.getNumDpsInits(); ++i) {
    OpOperand *output = genericOp.getDpsInitOperand(i);
    AffineMap oldOutputMap = genericOp.getMatchingIndexingMap(output);
    ArrayRef<int64_t> oldShape = genericOp.getShape(output);
    SmallVector<int64_t> thisOutputShape;

    SmallVector<AffineExpr> outputExpr;
    // Iterate one past the end so the split dim can be appended last.
    for (unsigned idx = 0; idx <= oldShape.size(); ++idx) {
      if (idx == insertSplitIndex) {
        thisOutputShape.push_back(outputDimsize);
        outputExpr.push_back(rewriter.getAffineDimExpr(insertSplitDimension));
      }
      if (idx < oldShape.size()) {
        thisOutputShape.push_back(oldShape[idx]);
        unsigned dim = oldOutputMap.getDimPosition(idx);
        outputExpr.push_back(rewriter.getAffineDimExpr(
            dim < insertSplitDimension ? dim : dim + 1));
      }
    }

    AffineMap newOutputMap = AffineMap::get(oldOutputMap.getNumDims() + 1, 0,
                                            outputExpr, rewriter.getContext());
    newMaps.push_back(newOutputMap);
    newOutputShapes.push_back(thisOutputShape);
  }

  // Handle dynamic dimensions for identity value tensor.
  // NOTE(review): dynamic sizes are taken from input operand 0 at the same
  // dim position — assumes input/output dynamic dims line up for argmax;
  // confirm for inputs with transposed indexing maps.
  SmallVector<Value> dynValDims;
  SmallVector<int64_t> newOutputShape = newOutputShapes[0];
  for (size_t i = 0; i < newOutputShape.size(); ++i) {
    if (ShapedType::isDynamic(newOutputShape[i])) {
      dynValDims.push_back(rewriter.create<tensor::DimOp>(
          loc, genericOp.getDpsInputOperand(0)->get(), i));
    }
  }

  Type valueElemType = genericOp.getRegionOutputArgs()[0].getType();
  Value emptyValTensor = rewriter.create<tensor::EmptyOp>(
      loc, newOutputShape, valueElemType, dynValDims);
  Value constantOp = rewriter.create<arith::ConstantOp>(loc, *identity);
  Value identityVal =
      rewriter.create<linalg::FillOp>(loc, constantOp, emptyValTensor)
          .getResult(0);

  // Handle dynamic dimensions for identity index tensor.
  SmallVector<Value> dynIdxDims;
  newOutputShape = newOutputShapes[1];
  for (size_t i = 0; i < newOutputShape.size(); ++i) {
    if (ShapedType::isDynamic(newOutputShape[i])) {
      dynIdxDims.push_back(rewriter.create<tensor::DimOp>(
          loc, genericOp.getDpsInputOperand(0)->get(), i));
    }
  }
  Type idxElemType = genericOp.getRegionOutputArgs()[1].getType();
  Value zeroIdx = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getZeroAttr(idxElemType));
  Value idxInitTensor = rewriter.create<tensor::EmptyOp>(
      loc, newOutputShape, idxElemType, dynIdxDims);
  Value identityIndex =
      rewriter.create<linalg::FillOp>(loc, zeroIdx, idxInitTensor).getResult(0);

  // Insert a parallel iterator for the new split dimension, keeping the
  // relative order of the original iterators.
  SmallVector<utils::IteratorType> newIteratorTypes;
  for (auto [index, iteratorType] :
       llvm::enumerate(genericOp.getIteratorTypesArray())) {
    if (insertSplitDimension == index)
      newIteratorTypes.push_back(utils::IteratorType::parallel);
    newIteratorTypes.push_back(iteratorType);
  }
  if (insertSplitDimension == genericOp.getIteratorTypesArray().size()) {
    newIteratorTypes.push_back(utils::IteratorType::parallel);
  }

  // Create partial linalg.generic op with global index computation.
  Value tileSize = rewriter.create<arith::ConstantIndexOp>(loc, ratio);
  auto partialOp = rewriter.create<linalg::GenericOp>(
      loc, TypeRange{identityVal.getType(), identityIndex.getType()}, newInputs,
      ValueRange{identityVal, identityIndex}, newMaps, newIteratorTypes);

  rewriter.inlineRegionBefore(genericOp.getRegion(), partialOp.getRegion(),
                              partialOp.getRegion().begin());

  Block &body = partialOp.getRegion().front();
  rewriter.setInsertionPointToStart(&body);

  // NOTE(review): these positions assume the split dim is inserted at or
  // before the reduction dim; in the innerParallel case both resolve to
  // reductionDim + 1 — confirm the intended inner/outer dims for that mode.
  unsigned innerIdxDim = reductionDim + 1;
  unsigned outerIdxDim = insertSplitDimension;

  // Compute global index (gidx) for reduction when the original reduction
  // dimension is split into [outerIdx, innerIdx] using `ratio`. This is used
  // to correctly compute the global index for comparisons and index selection.
  Value outerIdx = rewriter.create<linalg::IndexOp>(loc, outerIdxDim);
  Value innerIdx = rewriter.create<linalg::IndexOp>(loc, innerIdxDim);
  Value offset = rewriter.create<arith::MulIOp>(loc, outerIdx, tileSize);
  Value gidx = rewriter.create<arith::AddIOp>(loc, offset, innerIdx);

  // The collector guarantees this is an arith.select, so cast (not dyn_cast).
  // The select lives in the region that was just moved into partialOp.
  auto selectOp = cast<arith::SelectOp>(combinerOps.selectOp);
  Value oldIdx = selectOp.getTrueValue();
  Value newIdx = gidx;
  if (oldIdx.getType() != gidx.getType()) {
    newIdx = rewriter.create<arith::IndexCastOp>(loc, oldIdx.getType(), gidx);
  }
  // Operand 1 is the true value: select now yields the global index.
  selectOp.setOperand(1, newIdx);
  rewriter.setInsertionPointAfter(partialOp);

  // Build the final reduction over the inserted dimension.
  unsigned intermRank = newOutputShape.size();
  AffineMap valueMap = rewriter.getMultiDimIdentityMap(intermRank);
  AffineMap indexMap = valueMap;
  SmallVector<utils::IteratorType> reductionIteratorTypes;
  SmallVector<AffineExpr> resultExprs;
  for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
    if (insertSplitIndex == i) {
      reductionIteratorTypes.push_back(utils::IteratorType::reduction);
    } else {
      resultExprs.push_back(rewriter.getAffineDimExpr(i));
      reductionIteratorTypes.push_back(utils::IteratorType::parallel);
    }
  }

  AffineMap outputMap =
      AffineMap::get(intermRank, 0, resultExprs, rewriter.getContext());
  SmallVector<AffineMap> finalReductionMaps = {valueMap, indexMap, outputMap,
                                               outputMap};

  // Create block for final reduction region: same combiner (maximumf, cmpf,
  // select) cloned from the original body, rewired to the block arguments.
  auto finalReduction = rewriter.create<linalg::GenericOp>(
      loc, genericOp.getResultTypes(),
      ValueRange{partialOp.getResult(0), partialOp.getResult(1)},
      genericOp.getDpsInits(), finalReductionMaps, reductionIteratorTypes,
      [combinerOps](OpBuilder &b, Location loc, ValueRange inputs) {
        Operation *clonedMax = b.clone(*combinerOps.maxOp);
        clonedMax->setOperands({inputs[0], inputs[2]});
        Operation *clonedCmp = b.clone(*combinerOps.cmpOp);
        clonedCmp->setOperands({inputs[0], inputs[2]});
        Operation *clonedSel = b.clone(*combinerOps.selectOp);
        clonedSel->setOperands({clonedCmp->getResult(0), inputs[1], inputs[3]});
        b.create<linalg::YieldOp>(
            loc, ValueRange{clonedMax->getResult(0), clonedSel->getResult(0)});
      });

  rewriter.replaceOp(genericOp, finalReduction.getResults());
  // Init or alloc and fillOp are not applicable for argmax op; set to nullptr.
  return linalg::SplitReductionResult{
      /*initOrAlloc=*/nullptr, /*fillOp=*/nullptr,
      cast<linalg::LinalgOp>(partialOp.getOperation()), finalReduction};
}
1239+
9231240
} // namespace mlir::iree_compiler::IREE::LinalgExt

compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h

+61
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#define IREE_COMPILER_DIALECT_LINALGEXT_UTILS_UTILS_H_
99

1010
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
11+
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
1112
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
1213
#include "mlir/IR/Attributes.h"
1314
#include "mlir/IR/BuiltinTypes.h"
@@ -227,5 +228,65 @@ bool isaHorizontallyFusedContraction(Operation *op);
227228
/// Check if a linalg.generic is representing an argmax operation.
228229
bool isArgmaxOp(linalg::GenericOp genericOp);
229230

231+
/// Apply transformation to split a linalg.generic argmax reduction
232+
/// into a two-stage reduction using an additional parallel dimension.
233+
/// The transformation first computes a partial argmax over tiles (parallel),
234+
/// then reduces those results into a final result (reduction).
235+
///
236+
/// This pattern is specialized for reductions that yield both the maximum
237+
/// value and its index, using the combination of `arith.maximumf`,
238+
/// `arith.cmpf`, and `arith.select` ops. It assumes a known structure of the
239+
/// region and injects index computations to track global indices.
240+
///
241+
/// Returns the resulting partial and final linalg.generic ops, or failure
242+
/// if the pattern does not match or cannot be split.
243+
///
244+
/// Example:
245+
/// ```
246+
/// // Original argmax op reducing over dim=512
247+
/// %4:2 = linalg.generic {
248+
/// indexing_maps = [...],
249+
/// iterator_types = ["parallel", "reduction"]
250+
/// } ins(%arg0 : tensor<?x512xbf16>)
251+
/// outs(%out_val, %out_idx : tensor<?xbf16>, tensor<?xi64>) {
252+
/// ^bb0(%in: bf16, %out: bf16, %out_0: i64):
253+
/// %idx = linalg.index 1 : index
254+
/// %cast = arith.index_cast %idx : index to i64
255+
/// %max = arith.maximumf %in, %out : bf16
256+
/// %cmp = arith.cmpf ogt, %in, %out : bf16
257+
/// %sel = arith.select %cmp, %cast, %out_0 : i64
258+
/// linalg.yield %max, %sel : bf16, i64
259+
/// } -> (tensor<?xbf16>, tensor<?xi64>)
260+
/// ```
261+
/// To:
262+
/// ```
263+
/// // After splitting K=512 into 4 x 128
264+
/// %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] : tensor<?x512xbf16>
265+
/// into tensor<?x4x128xbf16>
266+
/// %init_val = linalg.fill ins(%cst : bf16) outs(%empty : tensor<?x4xbf16>)
267+
/// %init_idx = linalg.fill ins(%zero : i64) outs(%empty : tensor<?x4xi64>)
268+
/// %partial:2 = linalg.generic {
269+
/// indexing_maps = [...],
270+
/// iterator_types = ["parallel", "reduction"]
271+
/// } ins(%expanded : tensor<?x4x128xbf16>)
272+
/// outs(%init_val, %init_idx : tensor<?x4xbf16>, tensor<?x4xi64>) {
273+
/// // compute global index: outer_idx * 128 + inner_idx
274+
/// ...
275+
/// }
276+
///
277+
/// // Final argmax over the tile dimension (dim=1 of ?x4)
278+
/// %final:2 = linalg.generic {
279+
/// indexing_maps = [...],
280+
/// iterator_types = ["reduction"]
281+
/// } ins(%partial#0, %partial#1)
282+
/// outs(%out_val, %out_idx : tensor<?xbf16>, tensor<?xi64>) {
283+
/// // same combiner: maximumf, cmpf, select
284+
/// ...
285+
/// }
286+
/// ```
287+
FailureOr<linalg::SplitReductionResult>
288+
splitArgmaxReduction(RewriterBase &rewriter, linalg::GenericOp genericOp,
289+
linalg::ControlSplitReductionFn controlSplitReductionFn);
290+
230291
} // namespace mlir::iree_compiler::IREE::LinalgExt
231292
#endif // IREE_COMPILER_DIALECT_LINALGEXT_UTILS_UTILS_H_

0 commit comments

Comments
 (0)