Skip to content

Commit e59e177

Browse files
committed
move the code to transforms.h/cpp
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
1 parent 1ff6e7f commit e59e177

File tree

6 files changed

+375
-378
lines changed

6 files changed

+375
-378
lines changed

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/SplitReduction.cpp

+313
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
99
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
1010
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
11+
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
1112
#include "llvm/ADT/STLExtras.h"
1213
#include "mlir/Dialect/Arith/IR/Arith.h"
1314
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -389,4 +390,316 @@ splitReduction(RewriterBase &rewriter, LinalgExt::TopkOp topkOp,
389390
return success();
390391
}
391392

393+
struct ArgmaxCombinerOps {
394+
Operation *maxOp = nullptr; // arith.maximumf
395+
Operation *selectOp = nullptr; // arith.select
396+
Operation *cmpOp = nullptr; // arith.cmpf
397+
};
398+
399+
// Matches the combiner pattern in a linalg.generic argmax-style reduction:
400+
// Example MLIR:
401+
// %4:2 = linalg.generic {
402+
// indexing_maps = [...],
403+
// iterator_types = ["parallel", "reduction"]
404+
// } ins(%arg0 : tensor<?x128xbf16>) outs(%1, %3 : tensor<?xbf16>,
405+
// tensor<?xi64>) {
406+
// ^bb0(%in: bf16, %out: bf16, %out_0: i64):
407+
// %5 = linalg.index 1 : index
408+
// %6 = arith.index_cast %5 : index to i64
409+
// %7 = arith.maximumf %in, %out : bf16
410+
// %8 = arith.cmpf ogt, %in, %out : bf16
411+
// %9 = arith.select %8, %6, %out_0 : i64
412+
// linalg.yield %7, %9 : bf16, i64
413+
// } -> (tensor<?xbf16>, tensor<?xi64>)
414+
//
415+
// This function extracts the `arith.maximumf`, `arith.cmpf`, and `arith.select`
416+
// operations from the body to facilitate transformations such as split
417+
// reduction.
418+
static FailureOr<ArgmaxCombinerOps>
419+
collectArgmaxCombinerOps(linalg::GenericOp genericOp) {
420+
421+
assert(isArgmaxOp(genericOp) && "expected operation to be an argmax op");
422+
423+
ArgmaxCombinerOps ops;
424+
425+
auto yieldOp = cast<linalg::YieldOp>(genericOp.getBody()->getTerminator());
426+
427+
// Extract max value producer: arith.maximumf.
428+
Value maxResult = yieldOp.getOperand(0);
429+
auto maxOp = dyn_cast<arith::MaximumFOp>(maxResult.getDefiningOp());
430+
431+
// Extract index result producer: arith.select.
432+
Value indexResult = yieldOp.getOperand(1);
433+
auto selectOp = dyn_cast<arith::SelectOp>(indexResult.getDefiningOp());
434+
435+
// Extract the condition of the select, expected to be arith.cmpf with
436+
// predicate OGT.
437+
auto cmpOp = dyn_cast<arith::CmpFOp>(selectOp.getCondition().getDefiningOp());
438+
439+
ops.maxOp = maxOp;
440+
ops.selectOp = selectOp;
441+
ops.cmpOp = cmpOp;
442+
443+
return ops;
444+
}
445+
446+
/// Splits the single reduction of an argmax-style `linalg.generic` into two
/// stages:
///   1. a "partial" generic op that reduces tiles of the original reduction
///      dimension into an intermediate tensor carrying one extra (parallel)
///      dimension, tracking the index in the *original* reduction dimension;
///   2. a "final" generic op that reduces the extra dimension with the same
///      maximumf / cmpf / select combiner.
///
/// `controlSplitReductionFn` supplies the split `ratio`, the position
/// (`index`) at which the new dimension is inserted in the outputs, and
/// whether the new parallel dimension is the inner one (`innerParallel`).
///
/// Fails (leaving the IR untouched) when the ratio is <= 1, when the reduction
/// size is dynamic or not divisible by the ratio, when the insertion position
/// exceeds the output rank, when the combiner cannot be matched, or when no
/// neutral element is known for the max op. On success `genericOp` is replaced
/// by the results of the final reduction.
FailureOr<linalg::SplitReductionResult>
splitArgmaxReduction(RewriterBase &rewriter, linalg::GenericOp genericOp,
                     linalg::ControlSplitReductionFn controlSplitReductionFn) {
  assert(IREE::LinalgExt::isArgmaxOp(genericOp) &&
         "expected operation to be an argmax op");

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(genericOp);
  Location loc = genericOp->getLoc();

  linalg::SplitReductionOptions control = controlSplitReductionFn(genericOp);
  int64_t ratio = control.ratio;
  // Position of the new dimension in the output tensors.
  unsigned insertSplitIndex = control.index;
  // Position of the new parallel loop in the iteration space.
  unsigned insertSplitDimension = control.index;
  if (ratio <= 1) {
    return rewriter.notifyMatchFailure(
        genericOp, "split ratio needs to be greater than 1");
  }

  SmallVector<unsigned> dims;
  genericOp.getReductionDims(dims);

  unsigned reductionDim = dims[0];
  if (control.innerParallel) {
    insertSplitDimension = reductionDim + 1;
  }

  // Maps a loop dimension of the original op to its position in the split op,
  // which has one extra loop inserted at `insertSplitDimension`.
  auto shiftedDim = [insertSplitDimension](unsigned dim) -> unsigned {
    return dim < insertSplitDimension ? dim : dim + 1;
  };

  SmallVector<int64_t, 4> loopRanges = genericOp.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDim];

  if (reductionDimSize == ShapedType::kDynamic ||
      reductionDimSize % ratio != 0) {
    return rewriter.notifyMatchFailure(
        genericOp, "Reduction dimension not divisible by split ratio");
  }

  // The total number of output elements along the new dimension is
  // reductionDimSize / ratio. Computed only once the size is known to be
  // static and divisible.
  int64_t outputDimsize = reductionDimSize / ratio;

  if (insertSplitIndex >
      genericOp.getShape(genericOp.getDpsInitOperand(0)).size()) {
    return rewriter.notifyMatchFailure(genericOp,
                                       "Insert dimension position too large "
                                       "compared to intermediate tensor size");
  }

  FailureOr<ArgmaxCombinerOps> maybeOps = collectArgmaxCombinerOps(genericOp);
  if (failed(maybeOps))
    return rewriter.notifyMatchFailure(genericOp,
                                       "invalid combiner for argmax");

  ArgmaxCombinerOps combinerOps = *maybeOps;
  Operation *reductionOp = combinerOps.maxOp;

  std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp);
  if (!identity.has_value())
    return rewriter.notifyMatchFailure(
        genericOp, "Unknown identity value for the reduction");

  SmallVector<Value> newInputs;
  SmallVector<AffineMap> newMaps;
  // Calculate the new shapes and indexing maps of the input operands.
  for (OpOperand *operand : genericOp.getDpsInputOperands()) {
    AffineMap map = genericOp.getMatchingIndexingMap(operand);
    SmallVector<int64_t> newShape;
    SmallVector<AffineExpr> exprs;
    SmallVector<ReassociationIndices> reassociation;
    unsigned index = 0;
    for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
      unsigned dim = map.getDimPosition(idx);
      if (reductionDim == dim) {
        if (control.innerParallel) {
          newShape.push_back(ratio); // reduce
          newShape.push_back(genericOp.getShape(operand)[idx] /
                             ratio); // parallel (insert)
          exprs.push_back(rewriter.getAffineDimExpr(shiftedDim(dim)));
          exprs.push_back(rewriter.getAffineDimExpr(insertSplitDimension));
        } else {
          newShape.push_back(genericOp.getShape(operand)[idx] /
                             ratio); // parallel (insert)
          newShape.push_back(ratio); // reduce
          exprs.push_back(rewriter.getAffineDimExpr(insertSplitDimension));
          exprs.push_back(rewriter.getAffineDimExpr(shiftedDim(dim)));
        }
        // The reduction dimension expands into two result dimensions.
        reassociation.push_back({index++, index++});
        continue;
      }
      newShape.push_back(genericOp.getShape(operand)[idx]);
      exprs.push_back(rewriter.getAffineDimExpr(shiftedDim(dim)));
      reassociation.push_back({index++});
    }
    newMaps.push_back(
        AffineMap::get(map.getNumDims() + 1, 0, exprs, genericOp.getContext()));
    // If the shape is unchanged the input doesn't change.
    if (newShape == genericOp.getShape(operand)) {
      newInputs.push_back(operand->get());
      continue;
    }
    Type newType = RankedTensorType::get(
        newShape,
        cast<RankedTensorType>(operand->get().getType()).getElementType());

    Value newInput = rewriter.create<tensor::ExpandShapeOp>(
        loc, newType, operand->get(), reassociation);
    newInputs.push_back(newInput);
  }

  // Calculate the shapes and indexing maps of the intermediate outputs: the
  // original output with one dimension of size `outputDimsize` inserted at
  // `insertSplitIndex`.
  SmallVector<SmallVector<int64_t>> newOutputShapes;
  for (int64_t i = 0; i < genericOp.getNumDpsInits(); ++i) {
    OpOperand *output = genericOp.getDpsInitOperand(i);
    AffineMap oldOutputMap = genericOp.getMatchingIndexingMap(output);
    ArrayRef<int64_t> oldShape = genericOp.getShape(output);
    SmallVector<int64_t> thisOutputShape;

    SmallVector<AffineExpr> outputExpr;
    // `<=` so that the new dimension can also be appended at the very end.
    for (unsigned idx = 0; idx <= oldShape.size(); ++idx) {
      if (idx == insertSplitIndex) {
        thisOutputShape.push_back(outputDimsize);
        outputExpr.push_back(rewriter.getAffineDimExpr(insertSplitDimension));
      }
      if (idx < oldShape.size()) {
        thisOutputShape.push_back(oldShape[idx]);
        unsigned dim = oldOutputMap.getDimPosition(idx);
        outputExpr.push_back(rewriter.getAffineDimExpr(shiftedDim(dim)));
      }
    }

    AffineMap newOutputMap = AffineMap::get(oldOutputMap.getNumDims() + 1, 0,
                                            outputExpr, rewriter.getContext());
    newMaps.push_back(newOutputMap);
    newOutputShapes.push_back(thisOutputShape);
  }

  // Materializes a tensor.dim for every dynamic entry of `shape`.
  // NOTE(review): the dimension is queried on the first input operand at the
  // same position as in the intermediate output shape; this presumably relies
  // on the leading input and output dimensions lining up - confirm for
  // non-trivial indexing maps.
  auto collectDynamicDims = [&](ArrayRef<int64_t> shape) {
    SmallVector<Value> dynDims;
    for (size_t i = 0; i < shape.size(); ++i) {
      if (ShapedType::isDynamic(shape[i])) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(
            loc, genericOp.getDpsInputOperand(0)->get(), i));
      }
    }
    return dynDims;
  };

  // Intermediate value tensor, filled with the neutral element of the max op.
  SmallVector<int64_t> valOutputShape = newOutputShapes[0];
  SmallVector<Value> dynValDims = collectDynamicDims(valOutputShape);
  Type valueElemType = genericOp.getRegionOutputArgs()[0].getType();
  Value emptyValTensor = rewriter.create<tensor::EmptyOp>(
      loc, valOutputShape, valueElemType, dynValDims);
  Value constantOp = rewriter.create<arith::ConstantOp>(loc, *identity);
  Value identityVal =
      rewriter.create<linalg::FillOp>(loc, constantOp, emptyValTensor)
          .getResult(0);

  // Intermediate index tensor, filled with zero.
  SmallVector<int64_t> idxOutputShape = newOutputShapes[1];
  SmallVector<Value> dynIdxDims = collectDynamicDims(idxOutputShape);
  Type idxElemType = genericOp.getRegionOutputArgs()[1].getType();
  Value zeroIdx = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getZeroAttr(idxElemType));
  Value idxInitTensor = rewriter.create<tensor::EmptyOp>(
      loc, idxOutputShape, idxElemType, dynIdxDims);
  Value identityIndex =
      rewriter.create<linalg::FillOp>(loc, zeroIdx, idxInitTensor).getResult(0);

  // Iterator types of the partial op: the original ones with one parallel loop
  // inserted at `insertSplitDimension` (possibly appended at the end).
  SmallVector<utils::IteratorType> newIteratorTypes;
  for (auto [index, iteratorType] :
       llvm::enumerate(genericOp.getIteratorTypesArray())) {
    if (insertSplitDimension == index)
      newIteratorTypes.push_back(utils::IteratorType::parallel);
    newIteratorTypes.push_back(iteratorType);
  }
  if (insertSplitDimension == genericOp.getIteratorTypesArray().size()) {
    newIteratorTypes.push_back(utils::IteratorType::parallel);
  }

  // The original reduction dimension was expanded into an [outer, inner] pair
  // of tensor dimensions, so the index into the original dimension is
  //   outer * innerExtent + inner.
  // Without `innerParallel` the inserted parallel loop is the outer one and
  // the inner (reduction) extent is `ratio`; with `innerParallel` the roles
  // are swapped and the inner (parallel) extent is `outputDimsize`. The
  // reduction loop keeps its position unless the new loop was inserted at or
  // before it.
  unsigned newReductionDim = shiftedDim(reductionDim);
  unsigned outerIdxDim =
      control.innerParallel ? newReductionDim : insertSplitDimension;
  unsigned innerIdxDim =
      control.innerParallel ? insertSplitDimension : newReductionDim;
  int64_t innerExtent = control.innerParallel ? outputDimsize : ratio;

  // Create partial linalg.generic op with global index computation.
  Value tileSize = rewriter.create<arith::ConstantIndexOp>(loc, innerExtent);
  auto partialOp = rewriter.create<linalg::GenericOp>(
      loc, TypeRange{identityVal.getType(), identityIndex.getType()}, newInputs,
      ValueRange{identityVal, identityIndex}, newMaps, newIteratorTypes);

  rewriter.inlineRegionBefore(genericOp.getRegion(), partialOp.getRegion(),
                              partialOp.getRegion().begin());

  Block &body = partialOp.getRegion().front();
  rewriter.setInsertionPointToStart(&body);

  // Compute the global index (gidx) in the original reduction dimension. It
  // replaces the index the body previously selected so that the partial
  // results already carry indices valid for the unsplit tensor.
  Value outerIdx = rewriter.create<linalg::IndexOp>(loc, outerIdxDim);
  Value innerIdx = rewriter.create<linalg::IndexOp>(loc, innerIdxDim);
  Value offset = rewriter.create<arith::MulIOp>(loc, outerIdx, tileSize);
  Value gidx = rewriter.create<arith::AddIOp>(loc, offset, innerIdx);

  // `collectArgmaxCombinerOps` guarantees this is an arith.select.
  auto selectOp = cast<arith::SelectOp>(combinerOps.selectOp);
  Value oldIdx = selectOp.getTrueValue();
  Value newIdx = gidx;
  if (oldIdx.getType() != gidx.getType()) {
    newIdx = rewriter.create<arith::IndexCastOp>(loc, oldIdx.getType(), gidx);
  }
  // Operand 1 is the select's true value; go through the rewriter so that
  // listeners are notified of the in-place change.
  rewriter.modifyOpInPlace(selectOp,
                           [&]() { selectOp->setOperand(1, newIdx); });
  rewriter.setInsertionPointAfter(partialOp);

  // Final reduction: identity maps on the intermediate tensors, and the
  // inserted dimension is the (only) reduction loop.
  unsigned intermRank = idxOutputShape.size();
  AffineMap valueMap = rewriter.getMultiDimIdentityMap(intermRank);
  AffineMap indexMap = valueMap;
  SmallVector<utils::IteratorType> reductionIteratorTypes;
  SmallVector<AffineExpr> resultExprs;
  for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
    if (insertSplitIndex == i) {
      reductionIteratorTypes.push_back(utils::IteratorType::reduction);
    } else {
      resultExprs.push_back(rewriter.getAffineDimExpr(i));
      reductionIteratorTypes.push_back(utils::IteratorType::parallel);
    }
  }

  AffineMap outputMap =
      AffineMap::get(intermRank, 0, resultExprs, rewriter.getContext());
  SmallVector<AffineMap> finalReductionMaps = {valueMap, indexMap, outputMap,
                                               outputMap};

  // Create block for final reduction region. Block arguments are
  // (partial value, partial index, accumulated value, accumulated index).
  auto finalReduction = rewriter.create<linalg::GenericOp>(
      loc, genericOp.getResultTypes(),
      ValueRange{partialOp.getResult(0), partialOp.getResult(1)},
      genericOp.getDpsInits(), finalReductionMaps, reductionIteratorTypes,
      [combinerOps](OpBuilder &b, Location loc, ValueRange inputs) {
        Operation *clonedMax = b.clone(*combinerOps.maxOp);
        clonedMax->setOperands({inputs[0], inputs[2]});
        Operation *clonedCmp = b.clone(*combinerOps.cmpOp);
        clonedCmp->setOperands({inputs[0], inputs[2]});
        Operation *clonedSel = b.clone(*combinerOps.selectOp);
        clonedSel->setOperands({clonedCmp->getResult(0), inputs[1], inputs[3]});
        b.create<linalg::YieldOp>(
            loc, ValueRange{clonedMax->getResult(0), clonedSel->getResult(0)});
      });

  rewriter.replaceOp(genericOp, finalReduction.getResults());
  // Init or alloc and fillOp are not applicable for argmax op; set to nullptr.
  return linalg::SplitReductionResult{
      /*initOrAlloc=*/nullptr, /*fillOp=*/nullptr,
      cast<linalg::LinalgOp>(partialOp.getOperation()), finalReduction};
}
704+
392705
} // namespace mlir::iree_compiler::IREE::LinalgExt

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h

+60
Original file line numberDiff line numberDiff line change
@@ -118,4 +118,64 @@ FailureOr<std::pair<Value, Value>> rewriteFft(Operation *op, Value operand,
118118
int64_t fftLength,
119119
PatternRewriter &rewriter);
120120

121+
/// Apply transformation to split a linalg.generic argmax reduction
122+
/// into a two-stage reduction using an additional parallel dimension.
123+
/// The transformation first computes a partial argmax over tiles (parallel),
124+
/// then reduces those results into a final result (reduction).
125+
///
126+
/// This pattern is specialized for reductions that yield both the maximum
127+
/// value and its index, using the combination of `arith.maximumf`,
128+
/// `arith.cmpf`, and `arith.select` ops. It assumes a known structure of the
129+
/// region and injects index computations to track global indices.
130+
///
131+
/// Returns the resulting partial and final linalg.generic ops, or failure
132+
/// if the pattern does not match or cannot be split.
133+
///
134+
/// Example:
135+
/// ```
136+
/// // Original argmax op reducing over dim=512
137+
/// %4:2 = linalg.generic {
138+
/// indexing_maps = [...],
139+
/// iterator_types = ["parallel", "reduction"]
140+
/// } ins(%arg0 : tensor<?x512xbf16>)
141+
/// outs(%out_val, %out_idx : tensor<?xbf16>, tensor<?xi64>) {
142+
/// ^bb0(%in: bf16, %out: bf16, %out_0: i64):
143+
/// %idx = linalg.index 1 : index
144+
/// %cast = arith.index_cast %idx : index to i64
145+
/// %max = arith.maximumf %in, %out : bf16
146+
/// %cmp = arith.cmpf ogt, %in, %out : bf16
147+
/// %sel = arith.select %cmp, %cast, %out_0 : i64
148+
/// linalg.yield %max, %sel : bf16, i64
149+
/// } -> (tensor<?xbf16>, tensor<?xi64>)
150+
/// ```
151+
/// To:
152+
/// ```
153+
/// // After splitting K=512 into 4 x 128
154+
/// %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] : tensor<?x512xbf16>
155+
/// into tensor<?x4x128xbf16>
156+
/// %init_val = linalg.fill ins(%cst : bf16) outs(%empty : tensor<?x4xbf16>)
157+
/// %init_idx = linalg.fill ins(%zero : i64) outs(%empty : tensor<?x4xi64>)
158+
/// %partial:2 = linalg.generic {
159+
/// indexing_maps = [...],
160+
/// iterator_types = ["parallel", "parallel", "reduction"]
161+
/// } ins(%expanded : tensor<?x4x128xbf16>)
162+
/// outs(%init_val, %init_idx : tensor<?x4xbf16>, tensor<?x4xi64>) {
163+
/// // compute global index: outer_idx * 128 + inner_idx
164+
/// ...
165+
/// }
166+
///
167+
/// // Final argmax over the tile dimension (dim=1 of ?x4)
168+
/// %final:2 = linalg.generic {
169+
/// indexing_maps = [...],
170+
/// iterator_types = ["parallel", "reduction"]
171+
/// } ins(%partial#0, %partial#1)
172+
/// outs(%out_val, %out_idx : tensor<?xbf16>, tensor<?xi64>) {
173+
/// // same combiner: maximumf, cmpf, select
174+
/// ...
175+
/// }
176+
/// ```
177+
FailureOr<linalg::SplitReductionResult>
178+
splitArgmaxReduction(RewriterBase &rewriter, linalg::GenericOp genericOp,
179+
linalg::ControlSplitReductionFn controlSplitReductionFn);
180+
121181
}; // namespace mlir::iree_compiler::IREE::LinalgExt

0 commit comments

Comments
 (0)