[Codegen] split-k on argmax op #20717


Open

bangtianliu wants to merge 6 commits into main from split_k_argmax

Conversation

bangtianliu
Contributor

@bangtianliu bangtianliu commented May 3, 2025

This PR adds support for split reduction on the argmax operation.

The main question is: why couldn't upstream mlir::linalg::splitReduction be reused? The answer will also implicitly show what this PR is about.
The upstream mlir::linalg::splitReduction implementation is designed for single-result reductions (e.g., sum, max) and assumes a specific combiner behavior. In contrast, argmax is a two-result reduction, where both the value and the index must be tracked across the split.

Supporting this would require significant changes to the upstream utility. The relevant changes made by this PR include:

  • Handle two outputs in the reduction body and yield
  • Preserve the value-index pairing semantics throughout the split and recombine phases
  • Match a more complex combiner structure (maximumf + cmpf + select) and use all three combiner ops to generate the corresponding split and combine logic (see the sketch below).
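For illustration only, here is a minimal sketch (not the PR's actual code; the helper name and argument layout are hypothetical) of how the three matched combiner ops can be re-created in the body of the final combine step:

// Hypothetical sketch: the final reduction's body rebuilds the matched
// combiner (maximumf + cmpf + select) so the value/index pairing survives
// the split. Assumed block arguments: {partialVal, partialIdx, accVal, accIdx}.
auto buildCombinerBody = [&](OpBuilder &b, Location loc, ValueRange args) {
  Value maxVal = b.create<arith::MaximumFOp>(loc, args[0], args[2]);
  Value isGreater = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
                                            args[0], args[2]);
  Value selIdx = b.create<arith::SelectOp>(loc, isGreater, args[1], args[3]);
  b.create<linalg::YieldOp>(loc, ValueRange{maxVal, selIdx});
};
// The lambda matches the linalg::GenericOp body-builder callback signature,
// so it could be passed when creating the final reduction op.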

Other changes from this PR:

  • Format the utility function isArgmaxOp and move it from compiler/src/iree/compiler/Codegen/Utils/Utils.cpp to compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp.
  • Add support for dynamic dimensions in argmax, allowing parallel dimensions to be dynamic.

Issue: #20650

@bangtianliu bangtianliu marked this pull request as draft May 3, 2025 00:11
@bangtianliu bangtianliu force-pushed the split_k_argmax branch 4 times, most recently from 6f6e059 to ede09aa Compare May 4, 2025 20:55
@bangtianliu bangtianliu marked this pull request as ready for review May 4, 2025 20:57
@bangtianliu bangtianliu requested a review from pashu123 May 5, 2025 17:35
// Check identity value preparation
// CHECK: %[[CST:.+]] = arith.constant 0xFF80 : bf16
// CHECK: %[[ZERO:.+]] = arith.constant 0 : i64
// CHECK: %[[FINALVAL_EMPTY:.+]] = tensor.empty(%[[ARG1:.+]]) : tensor<?xbf16>
Contributor

I am assuming you didn't need to create these operations, i.e., they are the same as what existed before.

Contributor Author

@bangtianliu bangtianliu May 5, 2025

These are the identity (or initial) value and index used for the reduction.

Contributor

@pashu123 pashu123 left a comment

NITs. I'll take a look once again.

Contributor

@hanhanW hanhanW left a comment

I want to review the template part if you don't mind waiting for me a bit. I can do the review later today or tomorrow. No need to wait for me if others approve the change. I can send post-review comments or raise a PR myself after you land the PR.

@bangtianliu
Contributor Author

I want to review the template part if you don't mind waiting for me a bit. I can do the review later today or tomorrow. No need to wait for me if others approve the change. I can send post-review comments or raise a PR myself after you land the PR.

Thanks! No worries — it doesn’t seem like this PR is urgent to land.

@bangtianliu
Contributor Author

Per Mahesh's suggestion, this adds a new flag, --iree-dispatch-creation-enable-default-argmax-split-pattern, which enables a default argmax split-k pattern that tiles the reduction dimension with a tile size of 128.

Additionally, the meaning of splitArgmaxReductionRatio is updated to represent the number of elements per split along the reduction dimension (i.e., the tile size), rather than the number of splits.
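As a rough illustration (hypothetical variable names, not the actual code), the ratio now denotes the tile size, and the number of splits is derived from it:

// Illustrative arithmetic only: the ratio is the number of elements per split
// (the tile size), so the number of splits is the quotient.
int64_t reductionDimSize = 131072;        // e.g., the <1x131072xbf16> case below
int64_t splitArgmaxReductionRatio = 128;  // tile size along the reduction dim
int64_t numSplits = reductionDimSize / splitArgmaxReductionRatio;  // 1024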

In a local experiment, the runtime for a <1x131072xbf16> input was 9.50 ms with the flag disabled. With the default split pattern enabled (reduction tile size of 128), the runtime dropped to 0.187 ms.

@bangtianliu
Contributor Author

It seems that enabling this default pattern will cause some CI errors. I will look into them.

@bangtianliu bangtianliu force-pushed the split_k_argmax branch 3 times, most recently from d54fcca to d803e06 Compare May 7, 2025 16:03
Contributor

@hanhanW hanhanW left a comment

I did not do a detailed review, but here are a few high-level comments so far:

  1. We already have separation in the main function, i.e., SplitReductionPass::runOnOperation. Can we avoid the template? Templates are good for some cases, but the downside is that they are hard to debug and maintain, and it is also hard to trace the code. I'd prefer to avoid them as much as possible, and this falls into the case where we can: I don't see a need for a template here, and naming the method splitArgmaxReduction looks much easier to me. https://google.github.io/styleguide/cppguide.html#Template_metaprogramming
  2. Can you elaborate on why the upstream method can't be reused? I can see some reasons when I browse the code, but it is easier if you list why and what the differences are in the PR description. It is also helpful for reviewers and future contributors, who can track back to this commit.
  3. There is too much duplicated code between your implementation and the upstream splitReduction. Can you think of a way to refactor the code? I can help if you need some guidance.
  4. I think it is better to move splitArgmaxReduction() to LinalgExt/Transforms/Transforms.[h|cpp], because (a) I feel that we will upstream it in the future, (b) it makes the pass file cleaner, (c) it belongs with the transformations for LinalgExt (linalg extension), IMO, and (d) it can be reused by other passes when needed. RE (d): we have a good example of reusing splitReduction in CPU codegen to perform innermost reduction. I think it is better to expose the implementation outside this pass in the first place.
  5. Similar to (3), please add more details to the PR description. The PR is not trivial; we should have an informative PR description. Also, it'd be a plus if you mention that there is a refactoring of isArgmaxOp in the PR description. https://google.github.io/eng-practices/review/developer/cl-descriptions.html

The topK variant is located in the wrong file; I'll help fix it in a separate PR.

@bangtianliu bangtianliu force-pushed the split_k_argmax branch 2 times, most recently from 8c6750c to 53f44b3 Compare May 7, 2025 22:34
@bangtianliu
Contributor Author

I did not do a detailed review, but here are a few high-level comments so far:

  1. We already have separation in the main function, i.e., SplitReductionPass::runOnOperation. Can we avoid the template? Templates are good for some cases, but the downside is that they are hard to debug and maintain, and it is also hard to trace the code. I'd prefer to avoid them as much as possible, and this falls into the case where we can: I don't see a need for a template here, and naming the method splitArgmaxReduction looks much easier to me. https://google.github.io/styleguide/cppguide.html#Template_metaprogramming
  2. Can you elaborate on why the upstream method can't be reused? I can see some reasons when I browse the code, but it is easier if you list why and what the differences are in the PR description. It is also helpful for reviewers and future contributors, who can track back to this commit.
  3. There is too much duplicated code between your implementation and the upstream splitReduction. Can you think of a way to refactor the code? I can help if you need some guidance.
  4. I think it is better to move splitArgmaxReduction() to LinalgExt/Transforms/Transforms.[h|cpp], because (a) I feel that we will upstream it in the future, (b) it makes the pass file cleaner, (c) it belongs with the transformations for LinalgExt (linalg extension), IMO, and (d) it can be reused by other passes when needed. RE (d): we have a good example of reusing splitReduction in CPU codegen to perform innermost reduction. I think it is better to expose the implementation outside this pass in the first place.
  5. Similar to (3), please add more details to the PR description. The PR is not trivial; we should have an informative PR description. Also, it'd be a plus if you mention that there is a refactoring of isArgmaxOp in the PR description. https://google.github.io/eng-practices/review/developer/cl-descriptions.html

The topK variant is located in the wrong file; I'll help fix it in a separate PR.

Thanks for your time and feedback during the code reviews. I've removed the template, updated the PR description, and reformatted the code according to your suggestions—except for point 3. Could you please provide further guidance on that part?

@bangtianliu bangtianliu requested a review from hanhanW May 7, 2025 22:44
@bangtianliu
Contributor Author

oh, I missed this one:

I think it is better to move the splitArgmaxReduction() to LinalgExt/Transforms/Transforms.[h|cpp]

I will send another commit to address it.

@bangtianliu bangtianliu force-pushed the split_k_argmax branch 2 times, most recently from e59e177 to 45051bf Compare May 8, 2025 02:24
Contributor

@hanhanW hanhanW left a comment

I left some nits, and a big comment about how we generate expanded operands. Please take a look, thanks.

EDIT: I can't share the link before I submit the comment, here is the link to the big comment: #20717 (comment)

Contributor

I'd put the test in split_reduction.mlir. I only split files when there are too many test cases, and when that happens, I start wondering whether the pass has more responsibilities than it should. There are exceptions, though, and this case does not fit my mental model for such a split. So I suggest moving the test to split_reduction.mlir, where it is easier for others to find. (I think that is also how we structure the tests.)

Comment on lines +91 to +97
if (enableStaticArgmaxSplit) {
// Use default split-k pattern for argmax ops: split the reduction dim
// (e.g., 131072) into tiles of size 128, resulting in 1024 tiles (131072
// / 128). So, the split ratio refers to the tile size of the reduction
// dimension.
splitArgmaxReductionRatio = 128;
}
Contributor

Why do we need the additional flag instead of just making iree-dispatch-creation-split-argmax-reduction default to 128? The other question is: do we want this to happen by default? If so, we should document how 128 is derived. If it is an experimental number, we can also just mention that in the comment. People will ask why when they see the magic number, IMO.

My expectation is that the default value is 1 (and we don't need the enableStaticArgmaxSplit variable). You would use --iree-dispatch-creation-split-argmax-reduction=128 in your compilation config. We can set the default to a reasonable value when we are confident about this.

(I don't work on this area, so maybe we are already confident about this. Then it is okay to set it to 128 by default.)

Contributor Author

Why do we need the additional flag instead of just making iree-dispatch-creation-split-argmax-reduction default to 128? The other question is: do we want this to happen by default? If so, we should document how 128 is derived. If it is an experimental number, we can also just mention that in the comment. People will ask why when they see the magic number, IMO.

My expectation is that the default value is 1 (and we don't need the enableStaticArgmaxSplit variable). You would use --iree-dispatch-creation-split-argmax-reduction=128 in your compilation config. We can set the default to a reasonable value when we are confident about this.

(I don't work on this area, so maybe we are already confident about this. Then it is okay to set it to 128 by default.)

Here is the context: #20717 (comment)

We want to enable it by default so that the performance issue in #20650 can be solved.

auto matmulSplitReductionControlFn =
[&](linalg::LinalgOp op) -> linalg::SplitReductionOptions {
// For matmul make the new parallel dimension first so that it looks
// like a batch_matmul and can follow the same codegen.
return {int64_t(splitReductionRatio), 0, /*innerParallel=*/false};
return {int64_t(splitMatmulReductionRatio), 0, /*innerParallel=*/false};
Contributor

It is not your responsibility because the code was bad in the first place. Could you help remove the int64_t() cast since you're touching the code? I think we don't need the cast at all. The type is already int64_t.

(Side note: ideally, if we really need the cast, we should use a C++-style cast like static_cast. We should not use C-style casting.)
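For example, the touched line could simply read as follows (a sketch; per the comment above, splitMatmulReductionRatio is already int64_t):

// No cast needed: splitMatmulReductionRatio already has type int64_t.
return {splitMatmulReductionRatio, 0, /*innerParallel=*/false};
// If a conversion were ever required, prefer a C++-style cast:
// return {static_cast<int64_t>(splitMatmulReductionRatio), 0,
//         /*innerParallel=*/false};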

Comment on lines +418 to +444
static FailureOr<ArgmaxCombinerOps>
collectArgmaxCombinerOps(linalg::GenericOp genericOp) {

assert(isArgmaxOp(genericOp) && "expected operation to be an argmax op");

ArgmaxCombinerOps ops;

auto yieldOp = cast<linalg::YieldOp>(genericOp.getBody()->getTerminator());

// Extract max value producer: arith.maximumf.
Value maxResult = yieldOp.getOperand(0);
auto maxOp = dyn_cast<arith::MaximumFOp>(maxResult.getDefiningOp());

// Extract index result producer: arith.select.
Value indexResult = yieldOp.getOperand(1);
auto selectOp = dyn_cast<arith::SelectOp>(indexResult.getDefiningOp());

// Extract the condition of the select, expected to be arith.cmpf with
// predicate OGT.
auto cmpOp = dyn_cast<arith::CmpFOp>(selectOp.getCondition().getDefiningOp());

ops.maxOp = maxOp;
ops.selectOp = selectOp;
ops.cmpOp = cmpOp;

return ops;
}
Contributor

style nit 1: don't use blank lines when you don't have to. resist starting functions with a blank line. https://google.github.io/styleguide/cppguide.html#Vertical_Whitespace
nit 2: ops is not used until the end. we should move the declaration down.
style nit 3: I'd use cast because you are not handling the 'failure'. We expect it is castable and we already have the assertion in the first line.

Suggested change
static FailureOr<ArgmaxCombinerOps>
collectArgmaxCombinerOps(linalg::GenericOp genericOp) {
assert(isArgmaxOp(genericOp) && "expected operation to be an argmax op");
ArgmaxCombinerOps ops;
auto yieldOp = cast<linalg::YieldOp>(genericOp.getBody()->getTerminator());
// Extract max value producer: arith.maximumf.
Value maxResult = yieldOp.getOperand(0);
auto maxOp = dyn_cast<arith::MaximumFOp>(maxResult.getDefiningOp());
// Extract index result producer: arith.select.
Value indexResult = yieldOp.getOperand(1);
auto selectOp = dyn_cast<arith::SelectOp>(indexResult.getDefiningOp());
// Extract the condition of the select, expected to be arith.cmpf with
// predicate OGT.
auto cmpOp = dyn_cast<arith::CmpFOp>(selectOp.getCondition().getDefiningOp());
ops.maxOp = maxOp;
ops.selectOp = selectOp;
ops.cmpOp = cmpOp;
return ops;
}
static FailureOr<ArgmaxCombinerOps>
collectArgmaxCombinerOps(linalg::GenericOp genericOp) {
assert(isArgmaxOp(genericOp) && "expected operation to be an argmax op");
auto yieldOp = cast<linalg::YieldOp>(genericOp.getBody()->getTerminator());
// Extract max value producer: arith.maximumf.
Value maxResult = yieldOp.getOperand(0);
auto maxOp = dyn_cast<arith::MaximumFOp>(maxResult.getDefiningOp());
// Extract index result producer: arith.select.
Value indexResult = yieldOp.getOperand(1);
auto selectOp = dyn_cast<arith::SelectOp>(indexResult.getDefiningOp());
// Extract the condition of the select, expected to be arith.cmpf with
// predicate OGT.
auto cmpOp = dyn_cast<arith::CmpFOp>(selectOp.getCondition().getDefiningOp());
ArgmaxCombinerOps ops;
ops.maxOp = maxOp;
ops.selectOp = selectOp;
ops.cmpOp = cmpOp;
return ops;
}

(I'd suggest aggregate initialization, e.g., return {.maxOp = maxOp, ...}, if IREE were built with C++20. Unfortunately, we are not using C++20 atm.)
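A pre-C++20 alternative is plain aggregate initialization (a sketch only; the initializer order must match the member declaration order of ArgmaxCombinerOps, which is assumed here):

// Sketch: positional aggregate initialization, assuming the members are
// declared in the order maxOp, cmpOp, selectOp.
return ArgmaxCombinerOps{maxOp, cmpOp, selectOp};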

Comment on lines +134 to +176
/// Example:
/// ```
/// // Original argmax op reducing over dim=512
/// %4:2 = linalg.generic {
/// indexing_maps = [...],
/// iterator_types = ["parallel", "reduction"]
/// } ins(%arg0 : tensor<?x512xbf16>)
/// outs(%out_val, %out_idx : tensor<?xbf16>, tensor<?xi64>) {
/// ^bb0(%in: bf16, %out: bf16, %out_0: i64):
/// %idx = linalg.index 1 : index
/// %cast = arith.index_cast %idx : index to i64
/// %max = arith.maximumf %in, %out : bf16
/// %cmp = arith.cmpf ogt, %in, %out : bf16
/// %sel = arith.select %cmp, %cast, %out_0 : i64
/// linalg.yield %max, %sel : bf16, i64
/// } -> (tensor<?xbf16>, tensor<?xi64>)
/// ```
/// To:
/// ```
/// // After splitting K=512 into 4 x 128
/// %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] : tensor<?x512xbf16>
/// into tensor<?x4x128xbf16>
/// %init_val = linalg.fill ins(%cst : bf16) outs(%empty : tensor<?x4xbf16>)
/// %init_idx = linalg.fill ins(%zero : i64) outs(%empty : tensor<?x4xi64>)
/// %partial:2 = linalg.generic {
/// indexing_maps = [...],
/// iterator_types = ["parallel", "reduction"]
/// } ins(%expanded : tensor<?x4x128xbf16>)
/// outs(%init_val, %init_idx : tensor<?x4xbf16>, tensor<?x4xi64>) {
/// // compute global index: outer_idx * 128 + inner_idx
/// ...
/// }
///
/// // Final argmax over the tile dimension (dim=1 of ?x4)
/// %final:2 = linalg.generic {
/// indexing_maps = [...],
/// iterator_types = ["reduction"]
/// } ins(%partial#0, %partial#1)
/// outs(%out_val, %out_idx : tensor<?xbf16>, tensor<?xi64>) {
/// // same combiner: maximumf, cmpf, select
/// ...
/// }
/// ```
Contributor

Optional nit: remove the markdown-style block. I've seen people use it in IREE and MLIR, but I really don't get the point; I'm not a fan of using it because it adds visual noise to me. Marking it optional because there is no convention, IMO. We can just replace ``` with a blank line and indent the examples. I took a stab at the formatting a bit; please make the change if you think it looks better.

Suggested change
/// Example:
/// ```
/// // Original argmax op reducing over dim=512
/// %4:2 = linalg.generic {
/// indexing_maps = [...],
/// iterator_types = ["parallel", "reduction"]
/// } ins(%arg0 : tensor<?x512xbf16>)
/// outs(%out_val, %out_idx : tensor<?xbf16>, tensor<?xi64>) {
/// ^bb0(%in: bf16, %out: bf16, %out_0: i64):
/// %idx = linalg.index 1 : index
/// %cast = arith.index_cast %idx : index to i64
/// %max = arith.maximumf %in, %out : bf16
/// %cmp = arith.cmpf ogt, %in, %out : bf16
/// %sel = arith.select %cmp, %cast, %out_0 : i64
/// linalg.yield %max, %sel : bf16, i64
/// } -> (tensor<?xbf16>, tensor<?xi64>)
/// ```
/// To:
/// ```
/// // After splitting K=512 into 4 x 128
/// %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] : tensor<?x512xbf16>
/// into tensor<?x4x128xbf16>
/// %init_val = linalg.fill ins(%cst : bf16) outs(%empty : tensor<?x4xbf16>)
/// %init_idx = linalg.fill ins(%zero : i64) outs(%empty : tensor<?x4xi64>)
/// %partial:2 = linalg.generic {
/// indexing_maps = [...],
/// iterator_types = ["parallel", "reduction"]
/// } ins(%expanded : tensor<?x4x128xbf16>)
/// outs(%init_val, %init_idx : tensor<?x4xbf16>, tensor<?x4xi64>) {
/// // compute global index: outer_idx * 128 + inner_idx
/// ...
/// }
///
/// // Final argmax over the tile dimension (dim=1 of ?x4)
/// %final:2 = linalg.generic {
/// indexing_maps = [...],
/// iterator_types = ["reduction"]
/// } ins(%partial#0, %partial#1)
/// outs(%out_val, %out_idx : tensor<?xbf16>, tensor<?xi64>) {
/// // same combiner: maximumf, cmpf, select
/// ...
/// }
/// ```
/// Example: original argmax op reducing over dim=512
///
/// %4:2 = linalg.generic {
/// indexing_maps = [...],
/// iterator_types = ["parallel", "reduction"]
/// } ins(%arg0 : tensor<?x512xbf16>)
/// outs(%out_val, %out_idx : tensor<?xbf16>, tensor<?xi64>) {
/// ^bb0(%in: bf16, %out: bf16, %out_0: i64):
/// %idx = linalg.index 1 : index
/// %cast = arith.index_cast %idx : index to i64
/// %max = arith.maximumf %in, %out : bf16
/// %cmp = arith.cmpf ogt, %in, %out : bf16
/// %sel = arith.select %cmp, %cast, %out_0 : i64
/// linalg.yield %max, %sel : bf16, i64
/// } -> (tensor<?xbf16>, tensor<?xi64>)
///
/// To: splitting K=512 into 4 x 128 + final argmax over the tile dimension
/// (dim=1 of ?x4)
///
/// %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] : tensor<?x512xbf16>
/// into tensor<?x4x128xbf16>
/// %init_val = linalg.fill ins(%cst : bf16) outs(%empty : tensor<?x4xbf16>)
/// %init_idx = linalg.fill ins(%zero : i64) outs(%empty : tensor<?x4xi64>)
/// %partial:2 = linalg.generic {
/// indexing_maps = [...],
/// iterator_types = ["parallel", "reduction"]
/// } ins(%expanded : tensor<?x4x128xbf16>)
/// outs(%init_val, %init_idx : tensor<?x4xbf16>, tensor<?x4xi64>) {
/// // compute global index: outer_idx * 128 + inner_idx
/// ...
/// }
/// %final:2 = linalg.generic {
/// indexing_maps = [...],
/// iterator_types = ["reduction"]
/// } ins(%partial#0, %partial#1)
/// outs(%out_val, %out_idx : tensor<?xbf16>, tensor<?xi64>) {
/// // same combiner: maximumf, cmpf, select
/// ...
/// }

Comment on lines +475 to +483
// The total number of output elements along this new dimension is
// reductionDimSize / ratio.
int64_t outputDimsize = reductionDimSize / ratio;

if (reductionDimSize == ShapedType::kDynamic ||
reductionDimSize % ratio != 0) {
return rewriter.notifyMatchFailure(
genericOp, "Reduction dimension not divisible by split ratio");
}
Contributor

nit 1: I think we should move the outputDimSize computation to the chunk where it is used, i.e., the place where you compute the output shapes/maps/etc. It is weird that we do the computation before the check. The value could also be dynamic, in which case we don't need to compute outputDimSize at all, right?

nit 2: switch to new preferred style, ShapedType::isDynamic(reductionDimSize)

picky nit for the error message: We typically start the first sentence with a lower-case letter, because it matches error message styles commonly produced by other tools

https://llvm.org/docs/CodingStandards.html#error-and-warning-messages

Also, to match error message styles commonly produced by other tools, start the first sentence with a lower-case letter, and finish the last sentence without a period, if it would end in one otherwise.
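Putting the three nits together, the check could look something like this (a sketch, not a suggested-change block):

// Bail out before doing any arithmetic on the reduction size, use the
// preferred ShapedType::isDynamic form, and start the message in lower case.
if (ShapedType::isDynamic(reductionDimSize) || reductionDimSize % ratio != 0) {
  return rewriter.notifyMatchFailure(
      genericOp, "reduction dimension not divisible by split ratio");
}
// Compute the derived size only where it is actually needed.
int64_t outputDimSize = reductionDimSize / ratio;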

Comment on lines +505 to +610
SmallVector<AffineExpr> exprs;
SmallVector<ReassociationIndices> reassociation;
unsigned index = 0;
for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
unsigned dim = map.getDimPosition(idx);
if (reductionDim == dim) {
if (control.innerParallel) {
newShape.push_back(ratio); // reduce
newShape.push_back(genericOp.getShape(operand)[idx] /
ratio); // parallel (insert)
exprs.push_back(rewriter.getAffineDimExpr(
dim < insertSplitDimension ? dim : dim + 1));
exprs.push_back(rewriter.getAffineDimExpr(insertSplitDimension));
} else {
newShape.push_back(genericOp.getShape(operand)[idx] /
ratio); // parallel (insert)
newShape.push_back(ratio); // reduce
exprs.push_back(rewriter.getAffineDimExpr(insertSplitDimension));
exprs.push_back(rewriter.getAffineDimExpr(
dim < insertSplitDimension ? dim : dim + 1));
}
reassociation.push_back({index++, index++});
continue;
}
newShape.push_back(genericOp.getShape(operand)[idx]);
exprs.push_back(rewriter.getAffineDimExpr(
dim < insertSplitDimension ? dim : dim + 1));
reassociation.push_back({index++});
}
newMaps.push_back(
AffineMap::get(map.getNumDims() + 1, 0, exprs, genericOp.getContext()));
// If the shape is unchanged the input doesn't change.
if (newShape == genericOp.getShape(operand)) {
newInputs.push_back(operand->get());
continue;
}
Type newType = RankedTensorType::get(
newShape,
cast<RankedTensorType>(operand->get().getType()).getElementType());

Value newInput = rewriter.create<tensor::ExpandShapeOp>(
loc, newType, operand->get(), reassociation);
newInputs.push_back(newInput);
}

SmallVector<SmallVector<int64_t>> newOutputShapes;
SmallVector<AffineMap> outputMaps;
for (int i = 0; i < genericOp.getNumDpsInits(); ++i) {
OpOperand *output = genericOp.getDpsInitOperand(i);
AffineMap oldOutputMap = genericOp.getMatchingIndexingMap(output);
ArrayRef<int64_t> oldShape = genericOp.getShape(output);
SmallVector<int64_t> thisOutputShape;

SmallVector<AffineExpr> outputExpr;
for (unsigned idx = 0; idx <= oldShape.size(); ++idx) {
if (idx == insertSplitIndex) {
thisOutputShape.push_back(outputDimsize);
outputExpr.push_back(rewriter.getAffineDimExpr(insertSplitDimension));
}
if (idx < oldShape.size()) {
thisOutputShape.push_back(oldShape[idx]);
unsigned dim = oldOutputMap.getDimPosition(idx);
outputExpr.push_back(rewriter.getAffineDimExpr(
dim < insertSplitDimension ? dim : dim + 1));
}
}

AffineMap newOutputMap = AffineMap::get(oldOutputMap.getNumDims() + 1, 0,
outputExpr, rewriter.getContext());
newMaps.push_back(newOutputMap);
newOutputShapes.push_back(thisOutputShape);
}

// Handle dynamic dimensions for identity value tensor.
SmallVector<Value> dynValDims;
SmallVector<int64_t> newOutputShape = newOutputShapes[0];
for (size_t i = 0; i < newOutputShape.size(); ++i) {
if (ShapedType::isDynamic(newOutputShape[i])) {
dynValDims.push_back(rewriter.create<tensor::DimOp>(
loc, genericOp.getDpsInputOperand(0)->get(), i));
}
}

Type valueElemType = genericOp.getRegionOutputArgs()[0].getType();
Value emptyValTensor = rewriter.create<tensor::EmptyOp>(
loc, newOutputShape, valueElemType, dynValDims);
Value constantOp = rewriter.create<arith::ConstantOp>(loc, *identity);
Value identityVal =
rewriter.create<linalg::FillOp>(loc, constantOp, emptyValTensor)
.getResult(0);

// Handle dynamic dimensions for identity index tensor.
SmallVector<Value> dynIdxDims;
newOutputShape = newOutputShapes[1];
for (size_t i = 0; i < newOutputShape.size(); ++i) {
if (ShapedType::isDynamic(newOutputShape[i])) {
dynIdxDims.push_back(rewriter.create<tensor::DimOp>(
loc, genericOp.getDpsInputOperand(0)->get(), i));
}
}
Contributor

IIUC, this part is generating the expanded operands and the properties of the partial reduction op. Furthermore, it creates a "folded" linalg.fill op on the expanded domain.

I think the logic can be much simpler because we have a few more patterns that help a lot. First, you can declare an expandValue() method that applies the expand_shape to all the operands, and the partial reduction op takes them as inputs and outputs. Here, I think you'd have a question about how to fold the expand_shape away.

  1. We have canonicalization patterns on FillOp that swap the (fill, reshape) pair.
  2. We can explicitly populate the patterns that fold reshapes into empty ops. I don't know why upstream does not put them in the EmptyOp canonicalization patterns, but they are valid to use, IMO.
  3. For creating the expand_shape op, you don't need to compute the dynamic values yourself. You can use the builder that takes a result type and a reassociation map. The new type can be constructed easily, so I'm not going to dig into details. After you get the type, you can use the getReassociationIndicesForReshape method to get the map, and use them to create the expand_shape op. I checked that the upstream builder infers the output shape for you and builds the dynamic values for you.

This way, we can generate the expanded operands cleanly.
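A minimal sketch of such an expandValue() helper (the name follows the suggestion above; this is illustrative, not the PR's code), assuming getReassociationIndicesForReshape and the shape-inferring tensor::ExpandShapeOp builder:

// Sketch: expand `operand` to `expandedShape`, letting the upstream builder
// infer the dynamic output sizes. Returns the operand unchanged if no
// expansion is needed.
static Value expandValue(RewriterBase &rewriter, Location loc, Value operand,
                         ArrayRef<int64_t> expandedShape) {
  auto srcType = cast<RankedTensorType>(operand.getType());
  auto dstType =
      RankedTensorType::get(expandedShape, srcType.getElementType());
  if (srcType == dstType)
    return operand;
  std::optional<SmallVector<ReassociationIndices>> reassociation =
      getReassociationIndicesForReshape(srcType, dstType);
  assert(reassociation && "expected a valid expand_shape reassociation");
  return rewriter.create<tensor::ExpandShapeOp>(loc, dstType, operand,
                                                *reassociation);
}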

The next problem is the generic op properties. For indexing maps, I think you can use argmaxOp.getIndexingMapsArray() to get all the indexing maps; it returns a SmallVector<AffineMap>. Then you can declare a getExpandedIndexingMap() and iterate over all the indexing maps to get the expanded ones.

Then you have all the properties except the iterators. The iterator change is below, so I'll leave the comment there.

I played a bit with the IR; I think you will generate something like the following. (Please ignore the transform script; it is just for kicking in the patterns that fold expand_shape away for empty ops.)

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map2 = affine_map<(d0, d1) -> (d0, d1)>
#map3 = affine_map<(d0, d1) -> (d0)>
module {
  util.func public @argmax(%arg0: tensor<?x131072xbf16>, %arg1: index) -> tensor<?xi64> {
    %cst = arith.constant 0xFF80 : bf16
    %c0_i64 = arith.constant 0 : i64
    %0 = tensor.empty(%arg1) : tensor<?xbf16>
    %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<?xbf16>) -> tensor<?xbf16>
    %2 = tensor.empty(%arg1) : tensor<?xi64>
    %3 = linalg.fill ins(%c0_i64 : i64) outs(%2 : tensor<?xi64>) -> tensor<?xi64>
    %c128 = arith.constant 128 : index
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg0, %c0 : tensor<?x131072xbf16>
    %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [%dim, 1024, 128] : tensor<?x131072xbf16> into tensor<?x1024x128xbf16>
    %c1024 = arith.constant 1024 : index
    %4 = arith.divsi %dim, %c1024 : index
    %expanded_0 = tensor.expand_shape %1 [[0, 1]] output_shape [%dim, 1024] : tensor<?xbf16> into tensor<?x1024xbf16>
    %expanded_1 = tensor.expand_shape %3 [[0, 1]] output_shape [%dim, 1024] : tensor<?xi64> into tensor<?x1024xi64>
    %5:2 = linalg.generic {indexing_maps = [#map, #map1, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded : tensor<?x1024x128xbf16>) outs(%expanded_0, %expanded_1 : tensor<?x1024xbf16>, tensor<?x1024xi64>) {
    ^bb0(%in: bf16, %out: bf16, %out_2: i64):
      %7 = linalg.index 1 : index
      %8 = linalg.index 2 : index
      %9 = arith.muli %7, %c128 : index
      %10 = arith.addi %9, %8 : index
      %11 = arith.index_cast %10 : index to i64
      %12 = arith.maximumf %in, %out : bf16
      %13 = arith.cmpf ogt, %in, %out : bf16
      %14 = arith.select %13, %11, %out_2 : i64
      linalg.yield %12, %14 : bf16, i64
    } -> (tensor<?x1024xbf16>, tensor<?x1024xi64>)
    %6:2 = linalg.generic {indexing_maps = [#map2, #map2, #map3, #map3], iterator_types = ["parallel", "reduction"]} ins(%5#0, %5#1 : tensor<?x1024xbf16>, tensor<?x1024xi64>) outs(%1, %3 : tensor<?xbf16>, tensor<?xi64>) {
    ^bb0(%in: bf16, %in_2: i64, %out: bf16, %out_3: i64):
      %7 = arith.maximumf %in, %out : bf16
      %8 = arith.cmpf ogt, %in, %out : bf16
      %9 = arith.select %8, %in_2, %out_3 : i64
      linalg.yield %7, %9 : bf16, i64
    } -> (tensor<?xbf16>, tensor<?xi64>)
    util.return %6#1 : tensor<?xi64>
  }
  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["util.func"]} in %arg0 : (!transform.any_op) -> !transform.op<"util.func">
      transform.apply_patterns to %0 {
        transform.apply_patterns.tensor.fold_tensor_empty
      } : !transform.op<"util.func">
      transform.yield
    }
  }
}

And I verified that the expand ops are folded away, if you run iree-opt --canonicalize -transform-interpreter ~/repro.mlir.

Does it make sense?

Comment on lines +619 to +628
SmallVector<utils::IteratorType> newIteratorTypes;
for (auto [index, iteratorType] :
llvm::enumerate(genericOp.getIteratorTypesArray())) {
if (insertSplitDimension == index)
newIteratorTypes.push_back(utils::IteratorType::parallel);
newIteratorTypes.push_back(iteratorType);
}
if (insertSplitDimension == genericOp.getIteratorTypesArray().size()) {
newIteratorTypes.push_back(utils::IteratorType::parallel);
}
Contributor

I think we can use the insert method from SmallVector. Something like

SmallVector<utils::IteratorType> newIteratorTypes = genericOp.getIteratorTypesArray();
newIteratorTypes.insert(newIteratorTypes.begin() + index, utils::IteratorType::parallel);

unsigned intermRank = newOutputShape.size();
AffineMap valueMap = rewriter.getMultiDimIdentityMap(intermRank);
AffineMap indexMap = valueMap;
SmallVector<utils::IteratorType> reductionIteratorTypes;
Contributor

I think the construction of reductionIteratorTypes can be easier, and it makes the below loop simpler.

SmallVector<utils::IteratorType> reductionIteratorTypes(intermRank, utils::IteratorType::parallel);
reductionIteratorTypes[insertSplitIndex] = utils::IteratorType::reduction;

Then you don't need the else statement in the loop below.

Comment on lines +464 to +470
SmallVector<unsigned> dims;
genericOp.getReductionDims(dims);

unsigned reductionDim = dims[0];
if (control.innerParallel) {
insertSplitDimension = reductionDim + 1;
}
Contributor

innerParallel seems to always be false for the argmax case; should we bail out when it is true?

Also, I'd remove the blank line in between; the lines belong to the same block to me.
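A possible early bail-out (sketch; the exact message wording is illustrative):

// The argmax split only supports the outer-parallel form, so reject the
// innerParallel configuration up front.
if (control.innerParallel) {
  return rewriter.notifyMatchFailure(
      genericOp, "innerParallel split is not supported for argmax");
}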

// CHECK: %[[FILLIDX:.+]] = linalg.fill ins(%{{.+}} : i64) outs(%[[INITIDX]] : tensor<?x1024xi64>) -> tensor<?x1024xi64>

// CHECK: %[[PARTIAL:.+]]:2 = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
Contributor

I missed a comment I had in mind: we may want to check the indexing maps as well.
