
Commit d803e06

add flag to enable default reduction pattern

Signed-off-by: Bangtian Liu <liubangtian@gmail.com>

1 parent 231a27c

File tree

2 files changed: +41 -22 lines changed


compiler/src/iree/compiler/DispatchCreation/SplitReduction.cpp

+32 -13
@@ -28,6 +28,16 @@ static llvm::cl::opt<int64_t>
     splitMatmulReductionRatio("iree-dispatch-creation-split-matmul-reduction",
                               llvm::cl::desc("split ratio"), llvm::cl::init(1));
 
+static llvm::cl::opt<bool> enableDefaultArgmaxSplitPattern(
+    "iree-dispatch-creation-enable-default-argmax-split-pattern",
+    llvm::cl::desc("Enable default argmax split-k for known reduction patterns "
+                   "with a reduction dimension of 128."),
+    llvm::cl::init(true));
+
+// Controls the tile size used when applying split-k to argmax reductions.
+// This value defines how many elements along the reduction dimension are
+// processed per tile (e.g., a value of 128 means each tile reduces over 128
+// elements).
 static llvm::cl::opt<int64_t> splitArgmaxReductionRatio(
     "iree-dispatch-creation-split-argmax-reduction",
     llvm::cl::desc("Ratio to split argmax. Set to 0 or 1 to disable"),
@@ -126,6 +136,11 @@ FailureOr<linalg::SplitReductionResult> splitReductionImpl<linalg::GenericOp>(
 
   SmallVector<int64_t, 4> loopRanges = genericOp.getStaticLoopRanges();
   int64_t reductionDimSize = loopRanges[reductionDim];
+
+  // The total number of output elements along this new dimension is
+  // reductionDimSize / ratio.
+  int64_t output_dimsize = reductionDimSize / ratio;
+
   if (reductionDimSize == ShapedType::kDynamic ||
       reductionDimSize % ratio != 0) {
     return rewriter.notifyMatchFailure(
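As a sanity check of the arithmetic this hunk introduces, here is a minimal standalone C++ sketch (not part of the commit) using the shapes from the updated test:

    #include <cassert>
    #include <cstdint>

    int main() {
      int64_t reductionDimSize = 131072; // reduction extent in the test input
      int64_t ratio = 128;               // tile size the pass now defaults to
      // Guard from splitReductionImpl: the ratio must divide the static size.
      assert(reductionDimSize % ratio == 0);
      // One partial result per tile along the new parallel dimension.
      int64_t output_dimsize = reductionDimSize / ratio;
      assert(output_dimsize == 1024); // matches tensor<?x1024x128xbf16> below
      return 0;
    }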
@@ -164,16 +179,16 @@ FailureOr<linalg::SplitReductionResult> splitReductionImpl<linalg::GenericOp>(
       unsigned dim = map.getDimPosition(idx);
       if (reductionDim == dim) {
         if (control.innerParallel) {
+          newShape.push_back(ratio); // reduce
           newShape.push_back(genericOp.getShape(operand)[idx] /
-                             ratio); // reduce
-          newShape.push_back(ratio); // parallel (insert)
+                             ratio); // parallel (insert)
           exprs.push_back(rewriter.getAffineDimExpr(
               dim < insertSplitDimension ? dim : dim + 1));
           exprs.push_back(rewriter.getAffineDimExpr(insertSplitDimension));
         } else {
-          newShape.push_back(ratio); // parallel (insert)
           newShape.push_back(genericOp.getShape(operand)[idx] /
-                             ratio); // reduce
+                             ratio); // parallel (insert)
+          newShape.push_back(ratio); // reduce
           exprs.push_back(rewriter.getAffineDimExpr(insertSplitDimension));
           exprs.push_back(rewriter.getAffineDimExpr(
               dim < insertSplitDimension ? dim : dim + 1));
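Reading off the swapped push_back order: with innerParallel == false the split dimension is now laid out as [size / ratio, ratio] (parallel tile index outer, reduction inner), where it was [ratio, size / ratio] before. A minimal sketch of just that layout choice (hypothetical helper, not IREE code):

    #include <cstdint>
    #include <vector>

    // Shapes of the two dims that replace the original reduction dim.
    std::vector<int64_t> splitDimShape(int64_t size, int64_t ratio,
                                       bool innerParallel) {
      if (innerParallel)
        return {ratio, size / ratio}; // reduce outer, parallel inner
      return {size / ratio, ratio};   // parallel outer, reduce inner
    }

    // splitDimShape(131072, 128, /*innerParallel=*/false) yields {1024, 128},
    // the tensor<?x1024x128xbf16> shape checked in the updated test.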
@@ -213,7 +228,7 @@ FailureOr<linalg::SplitReductionResult> splitReductionImpl<linalg::GenericOp>(
   SmallVector<AffineExpr> outputExpr;
   for (unsigned idx = 0; idx <= oldShape.size(); ++idx) {
     if (idx == insertSplitIndex) {
-      thisOutputShape.push_back(ratio);
+      thisOutputShape.push_back(output_dimsize);
       outputExpr.push_back(rewriter.getAffineDimExpr(insertSplitDimension));
     }
     if (idx < oldShape.size()) {
@@ -277,8 +292,7 @@ FailureOr<linalg::SplitReductionResult> splitReductionImpl<linalg::GenericOp>(
   }
 
   // Create partial linalg.generic op with global index computation.
-  Value tileSize =
-      rewriter.create<arith::ConstantIndexOp>(loc, reductionDimSize / ratio);
+  Value tileSize = rewriter.create<arith::ConstantIndexOp>(loc, ratio);
   auto partialOp = rewriter.create<linalg::GenericOp>(
       loc, TypeRange{identityVal.getType(), identityIndex.getType()}, newInputs,
       ValueRange{identityVal, identityIndex}, newMaps, newIteratorTypes);
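To see why the constant flips from reductionDimSize / ratio to ratio: in the new [size / ratio, ratio] layout the tile index is the outer dim and the element offset the inner one, so the original reduction index is outer * ratio + inner. A standalone check with the test's values (a sketch, not IREE code):

    #include <cassert>
    #include <cstdint>

    int main() {
      const int64_t size = 131072, ratio = 128;
      for (int64_t k = 0; k < size; ++k) {
        int64_t outer = k / ratio; // linalg.index 1 in the partial generic
        int64_t inner = k % ratio; // linalg.index 2 in the partial generic
        assert(outer * ratio + inner == k); // so tileSize must be `ratio`
      }
      return 0;
    }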
@@ -421,10 +435,19 @@ struct SplitReductionPass final
   void runOnOperation() override {
     if (splitMatmulReductionRatio.getValue() <= 1 &&
         topkSplitReductionRatio.empty() &&
-        splitArgmaxReductionRatio.getValue() <= 1) {
+        (splitArgmaxReductionRatio.getValue() <= 1 &&
+         !enableDefaultArgmaxSplitPattern)) {
       return;
     }
 
+    if (enableDefaultArgmaxSplitPattern) {
+      // Use default split-k pattern for argmax ops: split the reduction dim
+      // (e.g., 131072) into tiles of size 128, resulting in 1024 tiles (131072
+      // / 128). So, the split ratio refers to the tile size of the reduction
+      // dimension.
+      splitArgmaxReductionRatio = 128;
+    }
+
     MLIRContext *context = &getContext();
     auto funcOp = getOperation();
 
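Note that because this assignment runs unconditionally whenever the flag is on, it overwrites any user-supplied --iree-dispatch-creation-split-argmax-reduction value; a custom ratio therefore only takes effect together with --iree-dispatch-creation-enable-default-argmax-split-pattern=false (as in the opt-out invocation sketched earlier).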
@@ -471,11 +494,7 @@ struct SplitReductionPass final
           /*innerParallel=*/false};
     };
     for (auto op : argmaxCandidates) {
-      if (failed(splitReductionWrapper(rewriter, op,
-                                       argmaxSplitReductionControlFn))) {
-        op.emitOpError("failed to split argmax operation");
-        return signalPassFailure();
-      }
+      (void)splitReductionWrapper(rewriter, op, argmaxSplitReductionControlFn);
     }
 
     // Split topk ops.

compiler/src/iree/compiler/DispatchCreation/test/split_argmax_reduction.mlir

+9 -9
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline='builtin.module(util.func(iree-dispatch-creation-split-reduction-ops,cse, canonicalize))' --iree-dispatch-creation-split-argmax-reduction=4 %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline='builtin.module(util.func(iree-dispatch-creation-split-reduction-ops,cse, canonicalize))' %s | FileCheck %s
 
 util.func public @argmax(%arg0: tensor<?x131072xbf16>, %arg1: index) -> tensor<?xi64> {
   %cst = arith.constant 0xFF80 : bf16
@@ -35,16 +35,16 @@ util.func public @argmax(%arg0: tensor<?x131072xbf16>, %arg1: index) -> tensor<?
 // CHECK: %[[FINALIDX:.+]] = linalg.fill ins(%[[ZERO]] : i64) outs(%[[FINALIDX_EMPTY]] : tensor<?xi64>) -> tensor<?xi64>
 
 // Check partial reduction
-// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %arg0 {{\[}}[0], [1, 2]] output_shape [%{{.+}}, 4, 32768] : tensor<?x131072xbf16> into tensor<?x4x32768xbf16>
-// CHECK: %[[INITVAL:.+]] = tensor.empty(%{{.+}}) : tensor<?x4xbf16>
-// CHECK: %[[FILLVAL:.+]] = linalg.fill ins(%{{.+}} : bf16) outs(%[[INITVAL]] : tensor<?x4xbf16>) -> tensor<?x4xbf16>
-// CHECK: %[[INITIDX:.+]] = tensor.empty(%{{.+}}) : tensor<?x4xi64>
-// CHECK: %[[FILLIDX:.+]] = linalg.fill ins(%{{.+}} : i64) outs(%[[INITIDX]] : tensor<?x4xi64>) -> tensor<?x4xi64>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %arg0 {{\[}}[0], [1, 2]] output_shape [%{{.+}}, 1024, 128] : tensor<?x131072xbf16> into tensor<?x1024x128xbf16>
+// CHECK: %[[INITVAL:.+]] = tensor.empty(%{{.+}}) : tensor<?x1024xbf16>
+// CHECK: %[[FILLVAL:.+]] = linalg.fill ins(%{{.+}} : bf16) outs(%[[INITVAL]] : tensor<?x1024xbf16>) -> tensor<?x1024xbf16>
+// CHECK: %[[INITIDX:.+]] = tensor.empty(%{{.+}}) : tensor<?x1024xi64>
+// CHECK: %[[FILLIDX:.+]] = linalg.fill ins(%{{.+}} : i64) outs(%[[INITIDX]] : tensor<?x1024xi64>) -> tensor<?x1024xi64>
 
 // CHECK: %[[PARTIAL:.+]]:2 = linalg.generic
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
-// CHECK-SAME: ins(%[[EXPAND]] : tensor<?x4x32768xbf16>)
-// CHECK-SAME: outs(%[[FILLVAL]], %[[FILLIDX]] : tensor<?x4xbf16>, tensor<?x4xi64>)
+// CHECK-SAME: ins(%[[EXPAND]] : tensor<?x1024x128xbf16>)
+// CHECK-SAME: outs(%[[FILLVAL]], %[[FILLIDX]] : tensor<?x1024xbf16>, tensor<?x1024xi64>)
 // CHECK: ^bb0(%[[VAL:.+]]: bf16, %[[ACC:.+]]: bf16, %[[IDX:.+]]: i64)
 // CHECK: %[[OUTER:.+]] = linalg.index 1 : index
 // CHECK: %[[INNER:.+]] = linalg.index 2 : index
@@ -59,7 +59,7 @@ util.func public @argmax(%arg0: tensor<?x131072xbf16>, %arg1: index) -> tensor<?
 // Final reduction
 // CHECK: %[[FINAL:.+]]:2 = linalg.generic
 // CHECK-SAME: iterator_types = ["parallel", "reduction"]
-// CHECK-SAME: ins(%[[PARTIAL]]#0, %[[PARTIAL]]#1 : tensor<?x4xbf16>, tensor<?x4xi64>)
+// CHECK-SAME: ins(%[[PARTIAL]]#0, %[[PARTIAL]]#1 : tensor<?x1024xbf16>, tensor<?x1024xi64>)
 // CHECK-SAME: outs(%[[FINALVAL]], %[[FINALIDX]] : tensor<?xbf16>, tensor<?xi64>)
 // CHECK: ^bb0(%[[V1:.+]]: bf16, %[[I1:.+]]: i64, %[[V2:.+]]: bf16, %[[I2:.+]]: i64)
 // CHECK: %[[MAX2:.+]] = arith.maximumf %[[V1]], %[[V2]] : bf16
