@@ -28,6 +28,16 @@ static llvm::cl::opt<int64_t>
28
28
splitMatmulReductionRatio (" iree-dispatch-creation-split-matmul-reduction" ,
29
29
llvm::cl::desc (" split ratio" ), llvm::cl::init(1 ));
30
30
31
// Boolean command-line option, on by default. When set, SplitReductionPass's
// runOnOperation assigns 128 to splitArgmaxReductionRatio and does not take
// its early-exit path, regardless of the other split-ratio options.
// NOTE(review): the description says "reduction dimension of 128", while the
// comment in runOnOperation treats 128 as the tile size -- confirm wording.
+ static llvm::cl::opt<bool > enableDefaultArgmaxSplitPattern (
32
+ " iree-dispatch-creation-enable-default-argmax-split-pattern" ,
33
+ llvm::cl::desc (" Enable default argmax split-k for known reduction patterns "
34
+ " with a reduction dimension of 128." ),
35
+ llvm::cl::init(true ));
36
+
37
+ // Controls the tile size used when applying split-k to argmax reductions.
38
+ // This value defines how many elements along the reduction dimension are
39
+ // processed per tile (e.g., a value of 128 means each tile reduces over 128
40
+ // elements).
31
41
static llvm::cl::opt<int64_t > splitArgmaxReductionRatio (
32
42
" iree-dispatch-creation-split-argmax-reduction" ,
33
43
llvm::cl::desc (" Ratio to split argmax. Set to 0 or 1 to disable" ),
@@ -126,6 +136,11 @@ FailureOr<linalg::SplitReductionResult> splitReductionImpl<linalg::GenericOp>(
126
136
127
137
SmallVector<int64_t , 4 > loopRanges = genericOp.getStaticLoopRanges ();
128
138
int64_t reductionDimSize = loopRanges[reductionDim];
139
+
140
+ // Size of the newly inserted output dimension: reductionDimSize / ratio.
141
+ // NOTE(review): evaluated before the kDynamic / divisibility check below, so it is only meaningful once that check passes; ratio == 0 would divide by zero here -- confirm ratio is validated by the caller.
142
+ int64_t output_dimsize = reductionDimSize / ratio;
143
+
129
144
if (reductionDimSize == ShapedType::kDynamic ||
130
145
reductionDimSize % ratio != 0 ) {
131
146
return rewriter.notifyMatchFailure (
@@ -164,16 +179,16 @@ FailureOr<linalg::SplitReductionResult> splitReductionImpl<linalg::GenericOp>(
164
179
unsigned dim = map.getDimPosition (idx);
165
180
if (reductionDim == dim) {
166
181
if (control.innerParallel ) {
182
+ newShape.push_back (ratio); // reduce
167
183
newShape.push_back (genericOp.getShape (operand)[idx] /
168
- ratio); // reduce
169
- newShape.push_back (ratio); // parallel (insert)
184
+ ratio); // parallel (insert)
170
185
exprs.push_back (rewriter.getAffineDimExpr (
171
186
dim < insertSplitDimension ? dim : dim + 1 ));
172
187
exprs.push_back (rewriter.getAffineDimExpr (insertSplitDimension));
173
188
} else {
174
- newShape.push_back (ratio); // parallel (insert)
175
189
newShape.push_back (genericOp.getShape (operand)[idx] /
176
- ratio); // reduce
190
+ ratio); // parallel (insert)
191
+ newShape.push_back (ratio); // reduce
177
192
exprs.push_back (rewriter.getAffineDimExpr (insertSplitDimension));
178
193
exprs.push_back (rewriter.getAffineDimExpr (
179
194
dim < insertSplitDimension ? dim : dim + 1 ));
@@ -213,7 +228,7 @@ FailureOr<linalg::SplitReductionResult> splitReductionImpl<linalg::GenericOp>(
213
228
SmallVector<AffineExpr> outputExpr;
214
229
for (unsigned idx = 0 ; idx <= oldShape.size (); ++idx) {
215
230
if (idx == insertSplitIndex) {
216
- thisOutputShape.push_back (ratio );
231
+ thisOutputShape.push_back (output_dimsize );
217
232
outputExpr.push_back (rewriter.getAffineDimExpr (insertSplitDimension));
218
233
}
219
234
if (idx < oldShape.size ()) {
@@ -277,8 +292,7 @@ FailureOr<linalg::SplitReductionResult> splitReductionImpl<linalg::GenericOp>(
277
292
}
278
293
279
294
// Create partial linalg.generic op with global index computation.
280
- Value tileSize =
281
- rewriter.create <arith::ConstantIndexOp>(loc, reductionDimSize / ratio);
295
+ Value tileSize = rewriter.create <arith::ConstantIndexOp>(loc, ratio);
282
296
auto partialOp = rewriter.create <linalg::GenericOp>(
283
297
loc, TypeRange{identityVal.getType (), identityIndex.getType ()}, newInputs,
284
298
ValueRange{identityVal, identityIndex}, newMaps, newIteratorTypes);
@@ -421,10 +435,19 @@ struct SplitReductionPass final
421
435
void runOnOperation () override {
422
436
if (splitMatmulReductionRatio.getValue () <= 1 &&
423
437
topkSplitReductionRatio.empty () &&
424
- splitArgmaxReductionRatio.getValue () <= 1 ) {
438
+ (splitArgmaxReductionRatio.getValue () <= 1 &&
439
+ !enableDefaultArgmaxSplitPattern)) {
425
440
return ;
426
441
}
427
442
443
+ if (enableDefaultArgmaxSplitPattern) {
444
+ // Default split-k pattern for argmax ops: the reduction dim (e.g., 131072)
445
+ // is split into tiles of size 128, giving 1024 tiles (131072 / 128); the
446
+ // "ratio" here is the tile size along the reduction dimension.
447
+ // NOTE(review): this assignment replaces any value the user passed for splitArgmaxReductionRatio, and the flag defaults to true -- confirm intended.
448
+ splitArgmaxReductionRatio = 128 ;
449
+ }
450
+
428
451
MLIRContext *context = &getContext ();
429
452
auto funcOp = getOperation ();
430
453
@@ -471,11 +494,7 @@ struct SplitReductionPass final
471
494
/* innerParallel=*/ false };
472
495
};
473
496
for (auto op : argmaxCandidates) {
474
- if (failed (splitReductionWrapper (rewriter, op,
475
- argmaxSplitReductionControlFn))) {
476
- op.emitOpError (" failed to split argmax operation" );
477
- return signalPassFailure ();
478
- }
497
+ (void )splitReductionWrapper (rewriter, op, argmaxSplitReductionControlFn);
479
498
}
480
499
481
500
// Split topk ops.
0 commit comments