@@ -920,4 +920,321 @@ bool isArgmaxOp(linalg::GenericOp genericOp) {
  return true;
}

+struct ArgmaxCombinerOps {
+  Operation *maxOp = nullptr;    // arith.maximumf
+  Operation *selectOp = nullptr; // arith.select
+  Operation *cmpOp = nullptr;    // arith.cmpf
+};
+
+// Matches the combiner pattern in a linalg.generic argmax-style reduction:
+// Example MLIR:
+//   %4:2 = linalg.generic {
+//     indexing_maps = [...],
+//     iterator_types = ["parallel", "reduction"]
+//   } ins(%arg0 : tensor<?x128xbf16>) outs(%1, %3 : tensor<?xbf16>,
+//   tensor<?xi64>) {
+//   ^bb0(%in: bf16, %out: bf16, %out_0: i64):
+//     %5 = linalg.index 1 : index
+//     %6 = arith.index_cast %5 : index to i64
+//     %7 = arith.maximumf %in, %out : bf16
+//     %8 = arith.cmpf ogt, %in, %out : bf16
+//     %9 = arith.select %8, %6, %out_0 : i64
+//     linalg.yield %7, %9 : bf16, i64
+//   } -> (tensor<?xbf16>, tensor<?xi64>)
+//
+// This function extracts the `arith.maximumf`, `arith.cmpf`, and
+// `arith.select` operations from the body to facilitate transformations such
+// as split reduction.
+static FailureOr<ArgmaxCombinerOps>
+collectArgmaxCombinerOps(linalg::GenericOp genericOp) {
+  assert(IREE::LinalgExt::isArgmaxOp(genericOp) &&
+         "expected operation to be an argmax op");
+
+  ArgmaxCombinerOps ops;
+
+  auto yieldOp = cast<linalg::YieldOp>(genericOp.getBody()->getTerminator());
+
+  // Extract max value producer: arith.maximumf.
+  Value maxResult = yieldOp.getOperand(0);
+  auto maxOp = maxResult.getDefiningOp<arith::MaximumFOp>();
+  if (!maxOp)
+    return failure();
+
+  // Extract index result producer: arith.select.
+  Value indexResult = yieldOp.getOperand(1);
+  auto selectOp = indexResult.getDefiningOp<arith::SelectOp>();
+  if (!selectOp)
+    return failure();
+
+  // Extract the condition of the select, expected to be an arith.cmpf with
+  // predicate OGT.
+  auto cmpOp = selectOp.getCondition().getDefiningOp<arith::CmpFOp>();
+  if (!cmpOp)
+    return failure();
+
+  ops.maxOp = maxOp;
+  ops.selectOp = selectOp;
+  ops.cmpOp = cmpOp;
+
+  return ops;
+}
+
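+// Splits an argmax-style linalg.generic along its single reduction dimension
+// into a partial argmax over tiles of size `ratio`, followed by a final
+// argmax that combines the per-tile maxima and their global indices.
+//
+// A minimal caller sketch (the values here are hypothetical, not part of this
+// patch):
+//
+//   linalg::ControlSplitReductionFn control = [](linalg::LinalgOp) {
+//     return linalg::SplitReductionOptions{/*ratio=*/4, /*index=*/0,
+//                                          /*innerParallel=*/false};
+//   };
+//   FailureOr<linalg::SplitReductionResult> result =
+//       splitArgmaxReduction(rewriter, argmaxOp, control);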
+FailureOr<linalg::SplitReductionResult>
+splitArgmaxReduction(RewriterBase &rewriter, linalg::GenericOp genericOp,
+                     linalg::ControlSplitReductionFn controlSplitReductionFn) {
+  assert(IREE::LinalgExt::isArgmaxOp(genericOp) &&
+         "expected operation to be an argmax op");
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(genericOp);
+  Location loc = genericOp->getLoc();
+
+  linalg::SplitReductionOptions control = controlSplitReductionFn(genericOp);
+  int64_t ratio = control.ratio;
+  unsigned insertSplitIndex = control.index;
+  unsigned insertSplitDimension = control.index;
+  if (ratio <= 1) {
+    return rewriter.notifyMatchFailure(
+        genericOp, "split ratio needs to be greater than 1");
+  }
+
+  SmallVector<unsigned> dims;
+  genericOp.getReductionDims(dims);
+
+  unsigned reductionDim = dims[0];
+  if (control.innerParallel) {
+    insertSplitDimension = reductionDim + 1;
+  }
+
+  SmallVector<int64_t, 4> loopRanges = genericOp.getStaticLoopRanges();
+  int64_t reductionDimSize = loopRanges[reductionDim];
+
+  // The total number of output elements along this new dimension is
+  // reductionDimSize / ratio.
+  int64_t outputDimsize = reductionDimSize / ratio;
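+  // For instance, with reductionDimSize = 128 and ratio = 4, the partial
+  // result carries 128 / 4 = 32 elements along the inserted dimension.
+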
+  if (reductionDimSize == ShapedType::kDynamic ||
+      reductionDimSize % ratio != 0) {
+    return rewriter.notifyMatchFailure(
+        genericOp, "reduction dimension not divisible by split ratio");
+  }
+
+  if (insertSplitIndex >
+      genericOp.getShape(genericOp.getDpsInitOperand(0)).size()) {
+    return rewriter.notifyMatchFailure(genericOp,
+                                       "insert dimension position too large "
+                                       "compared to intermediate tensor size");
+  }
+
+  FailureOr<ArgmaxCombinerOps> maybeOps = collectArgmaxCombinerOps(genericOp);
+  if (failed(maybeOps))
+    return rewriter.notifyMatchFailure(genericOp,
+                                       "invalid combiner for argmax");
+
+  ArgmaxCombinerOps combinerOps = *maybeOps;
+  Operation *reductionOp = combinerOps.maxOp;
+
+  std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp);
+  if (!identity.has_value())
+    return rewriter.notifyMatchFailure(
+        genericOp, "unknown identity value for the reduction");
+
+  SmallVector<Value> newInputs;
+  SmallVector<AffineMap> newMaps;
+  // Calculate the new shapes and indexing maps of the input operands.
+  for (OpOperand *operand : genericOp.getDpsInputOperands()) {
+    AffineMap map = genericOp.getMatchingIndexingMap(operand);
+    SmallVector<int64_t> newShape;
+    SmallVector<AffineExpr> exprs;
+    SmallVector<ReassociationIndices> reassociation;
+    unsigned index = 0;
+    for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
+      unsigned dim = map.getDimPosition(idx);
+      if (reductionDim == dim) {
+        if (control.innerParallel) {
+          newShape.push_back(ratio); // reduce
+          newShape.push_back(genericOp.getShape(operand)[idx] /
+                             ratio); // parallel (insert)
+          exprs.push_back(rewriter.getAffineDimExpr(
+              dim < insertSplitDimension ? dim : dim + 1));
+          exprs.push_back(rewriter.getAffineDimExpr(insertSplitDimension));
+        } else {
+          newShape.push_back(genericOp.getShape(operand)[idx] /
+                             ratio); // parallel (insert)
+          newShape.push_back(ratio); // reduce
+          exprs.push_back(rewriter.getAffineDimExpr(insertSplitDimension));
+          exprs.push_back(rewriter.getAffineDimExpr(
+              dim < insertSplitDimension ? dim : dim + 1));
+        }
+        reassociation.push_back({index++, index++});
+        continue;
+      }
+      newShape.push_back(genericOp.getShape(operand)[idx]);
+      exprs.push_back(rewriter.getAffineDimExpr(
+          dim < insertSplitDimension ? dim : dim + 1));
+      reassociation.push_back({index++});
+    }
+    newMaps.push_back(
+        AffineMap::get(map.getNumDims() + 1, 0, exprs, genericOp.getContext()));
+    // If the shape is unchanged the input doesn't change.
+    if (newShape == genericOp.getShape(operand)) {
+      newInputs.push_back(operand->get());
+      continue;
+    }
+    Type newType = RankedTensorType::get(
+        newShape,
+        cast<RankedTensorType>(operand->get().getType()).getElementType());
+
+    Value newInput = rewriter.create<tensor::ExpandShapeOp>(
+        loc, newType, operand->get(), reassociation);
+    newInputs.push_back(newInput);
+  }
+
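+  // Following the example above: with ratio = 4, the input
+  // tensor<?x128xbf16> is expanded to tensor<?x32x4xbf16>, i.e. 32 tiles of
+  // 4 elements along the split reduction dimension.
+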
+  SmallVector<SmallVector<int64_t>> newOutputShapes;
+  for (int i = 0; i < genericOp.getNumDpsInits(); ++i) {
+    OpOperand *output = genericOp.getDpsInitOperand(i);
+    AffineMap oldOutputMap = genericOp.getMatchingIndexingMap(output);
+    ArrayRef<int64_t> oldShape = genericOp.getShape(output);
+    SmallVector<int64_t> thisOutputShape;
+
+    SmallVector<AffineExpr> outputExpr;
+    for (unsigned idx = 0; idx <= oldShape.size(); ++idx) {
+      if (idx == insertSplitIndex) {
+        thisOutputShape.push_back(outputDimsize);
+        outputExpr.push_back(rewriter.getAffineDimExpr(insertSplitDimension));
+      }
+      if (idx < oldShape.size()) {
+        thisOutputShape.push_back(oldShape[idx]);
+        unsigned dim = oldOutputMap.getDimPosition(idx);
+        outputExpr.push_back(rewriter.getAffineDimExpr(
+            dim < insertSplitDimension ? dim : dim + 1));
+      }
+    }
+
+    AffineMap newOutputMap = AffineMap::get(oldOutputMap.getNumDims() + 1, 0,
+                                            outputExpr, rewriter.getContext());
+    newMaps.push_back(newOutputMap);
+    newOutputShapes.push_back(thisOutputShape);
+  }
+
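+  // Continuing the example, with ratio = 4 and insertSplitIndex = 0 (an
+  // assumed control value), the partial outputs become tensor<32x?xbf16> and
+  // tensor<32x?xi64>: one max/index pair per tile.
+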
+  // Handle dynamic dimensions for the identity value tensor.
+  SmallVector<Value> dynValDims;
+  SmallVector<int64_t> newOutputShape = newOutputShapes[0];
+  for (size_t i = 0; i < newOutputShape.size(); ++i) {
+    if (ShapedType::isDynamic(newOutputShape[i])) {
+      dynValDims.push_back(rewriter.create<tensor::DimOp>(
+          loc, genericOp.getDpsInputOperand(0)->get(), i));
+    }
+  }
+
+  Type valueElemType = genericOp.getRegionOutputArgs()[0].getType();
+  Value emptyValTensor = rewriter.create<tensor::EmptyOp>(
+      loc, newOutputShape, valueElemType, dynValDims);
+  Value constantOp = rewriter.create<arith::ConstantOp>(loc, *identity);
+  Value identityVal =
+      rewriter.create<linalg::FillOp>(loc, constantOp, emptyValTensor)
+          .getResult(0);
+
+  // Handle dynamic dimensions for the identity index tensor.
+  SmallVector<Value> dynIdxDims;
+  newOutputShape = newOutputShapes[1];
+  for (size_t i = 0; i < newOutputShape.size(); ++i) {
+    if (ShapedType::isDynamic(newOutputShape[i])) {
+      dynIdxDims.push_back(rewriter.create<tensor::DimOp>(
+          loc, genericOp.getDpsInputOperand(0)->get(), i));
+    }
+  }
+
+  Type idxElemType = genericOp.getRegionOutputArgs()[1].getType();
+  Value zeroIdx = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getZeroAttr(idxElemType));
+  Value idxInitTensor = rewriter.create<tensor::EmptyOp>(
+      loc, newOutputShape, idxElemType, dynIdxDims);
+  Value identityIndex =
+      rewriter.create<linalg::FillOp>(loc, zeroIdx, idxInitTensor).getResult(0);
+
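+  // Zero is a safe initial index: the value tensor is filled with the
+  // reduction identity (-inf for maximumf), so for non-NaN inputs the first
+  // comparison always succeeds and overwrites the index.
+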
+  SmallVector<utils::IteratorType> newIteratorTypes;
+  for (auto [index, iteratorType] :
+       llvm::enumerate(genericOp.getIteratorTypesArray())) {
+    if (insertSplitDimension == index)
+      newIteratorTypes.push_back(utils::IteratorType::parallel);
+    newIteratorTypes.push_back(iteratorType);
+  }
+  if (insertSplitDimension == genericOp.getIteratorTypesArray().size()) {
+    newIteratorTypes.push_back(utils::IteratorType::parallel);
+  }
+
+  // Create the partial linalg.generic op with global index computation.
+  Value tileSize = rewriter.create<arith::ConstantIndexOp>(loc, ratio);
+  auto partialOp = rewriter.create<linalg::GenericOp>(
+      loc, TypeRange{identityVal.getType(), identityIndex.getType()}, newInputs,
+      ValueRange{identityVal, identityIndex}, newMaps, newIteratorTypes);
+
+  rewriter.inlineRegionBefore(genericOp.getRegion(), partialOp.getRegion(),
+                              partialOp.getRegion().begin());
+
+  Block &body = partialOp.getRegion().front();
+  rewriter.setInsertionPointToStart(&body);
+
+  unsigned innerIdxDim = reductionDim + 1;
+  unsigned outerIdxDim = insertSplitDimension;
+
+  // Compute the global index (gidx) for the reduction when the original
+  // reduction dimension is split into [outerIdx, innerIdx] using `ratio`.
+  // This is used to correctly compute the global index for comparisons and
+  // index selection.
+  Value outerIdx = rewriter.create<linalg::IndexOp>(loc, outerIdxDim);
+  Value innerIdx = rewriter.create<linalg::IndexOp>(loc, innerIdxDim);
+  Value offset = rewriter.create<arith::MulIOp>(loc, outerIdx, tileSize);
+  Value gidx = rewriter.create<arith::AddIOp>(loc, offset, innerIdx);
+
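+  // E.g., with ratio = 4, tile 5 and lane 2 recover the original position:
+  // gidx = 5 * 4 + 2 = 22.
+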
+  auto selectOp = cast<arith::SelectOp>(combinerOps.selectOp);
+  Value oldIdx = selectOp.getTrueValue();
+  Value newIdx = gidx;
+  if (oldIdx.getType() != gidx.getType()) {
+    newIdx = rewriter.create<arith::IndexCastOp>(loc, oldIdx.getType(), gidx);
+  }
+  // Replace the select's true value (the locally computed index) with the
+  // global index.
+  selectOp.setOperand(1, newIdx);
+  rewriter.setInsertionPointAfter(partialOp);
+
+  unsigned intermRank = newOutputShape.size();
+  AffineMap valueMap = rewriter.getMultiDimIdentityMap(intermRank);
+  AffineMap indexMap = valueMap;
+  SmallVector<utils::IteratorType> reductionIteratorTypes;
+  SmallVector<AffineExpr> resultExprs;
+  for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
+    if (insertSplitIndex == i) {
+      reductionIteratorTypes.push_back(utils::IteratorType::reduction);
+    } else {
+      resultExprs.push_back(rewriter.getAffineDimExpr(i));
+      reductionIteratorTypes.push_back(utils::IteratorType::parallel);
+    }
+  }
+
+  AffineMap outputMap =
+      AffineMap::get(intermRank, 0, resultExprs, rewriter.getContext());
+  SmallVector<AffineMap> finalReductionMaps = {valueMap, indexMap, outputMap,
+                                               outputMap};
+
+  // Create the final reduction that combines the partial results.
+  auto finalReduction = rewriter.create<linalg::GenericOp>(
+      loc, genericOp.getResultTypes(),
+      ValueRange{partialOp.getResult(0), partialOp.getResult(1)},
+      genericOp.getDpsInits(), finalReductionMaps, reductionIteratorTypes,
+      [combinerOps](OpBuilder &b, Location loc, ValueRange inputs) {
+        // Reuse the original combiner: max for the value, cmpf + select for
+        // the index of the running maximum.
+        Operation *clonedMax = b.clone(*combinerOps.maxOp);
+        clonedMax->setOperands({inputs[0], inputs[2]});
+        Operation *clonedCmp = b.clone(*combinerOps.cmpOp);
+        clonedCmp->setOperands({inputs[0], inputs[2]});
+        Operation *clonedSel = b.clone(*combinerOps.selectOp);
+        clonedSel->setOperands({clonedCmp->getResult(0), inputs[1], inputs[3]});
+        b.create<linalg::YieldOp>(
+            loc, ValueRange{clonedMax->getResult(0), clonedSel->getResult(0)});
+      });
+
+  rewriter.replaceOp(genericOp, finalReduction.getResults());
+  // Init or alloc and fillOp are not applicable to the argmax op; set them to
+  // nullptr.
+  return linalg::SplitReductionResult{
+      /*initOrAlloc=*/nullptr, /*fillOp=*/nullptr,
+      cast<linalg::LinalgOp>(partialOp.getOperation()), finalReduction};
+}
+
} // namespace mlir::iree_compiler::IREE::LinalgExt