diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td index 278fedbbd3cb4..5b056fff362b6 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -382,4 +382,179 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>, let hasVerifier = 1; } +//===----------------------------------------------------------------------===// +// Sparse Tensor Custom Linalg.Generic Operations. +//===----------------------------------------------------------------------===// + +def SparseTensor_LinalgIntersectOp : SparseTensor_Op<"linalg_intersect", [NoSideEffect, SameTypeOperands]> { + let summary = "Custom intersect operation within linalg.generic"; + let description = [{ + Custom intersect operation for use within `linalg.generic`. + The actual operation is held in a block for embedding by the sparse tensor dialect. + The final value in the block must be a sparse_tensor.linalg_yield. + + Example: + ```mlir + %result = sparse_tensor.linalg_intersect %a, %b : f64 to i8 { + ^bb0(%v0: f64, %v1: f64): + %cmp = arith.cmpf "oeq", %v0, %v1 : f64 + %ret_i8 = arith.extui %cmp : i1 to i8 + sparse_tensor.linalg_yield %ret_i8 : i8 + } + ``` + }]; + + let arguments = (ins AnyType:$a, AnyType:$b); + let results = (outs AnyType:$output); + let regions = (region SizedRegion<1>:$formula); + + let assemblyFormat = [{ + $a `,` $b attr-dict `:` type($a) `to` type($output) $formula + }]; + let hasVerifier = 1; +} + +def SparseTensor_LinalgUnionOp : SparseTensor_Op<"linalg_union", [NoSideEffect, SameTypeOperands]> { + let summary = "Custom union operation within linalg.generic"; + let description = [{ + Custom union operation for use within `linalg.generic`. + The actual operation is held in a block for embedding by the sparse tensor dialect. + The final value in the block must be a sparse_tensor.linalg_yield. + + Example: + ```mlir + %result = sparse_tensor.linalg_union %a, %b: f64 to f64 { + ^bb0(%v0: f64, %v1: f64): + %cmp = arith.cmpf "olt", %v0, %v1 : f64 + %smaller = select %cmp, %v0, %v1 : f64 + sparse_tensor.linalg_yield %smaller : f64 + } + ``` + }]; + + let arguments = (ins AnyType:$a, AnyType:$b); + let results = (outs AnyType:$output); + let regions = (region SizedRegion<1>:$formula); + + let assemblyFormat = [{ + $a `,` $b attr-dict `:` type($a) `to` type($output) $formula + }]; + let hasVerifier = 1; +} + +def SparseTensor_LinalgReduceOp : SparseTensor_Op<"linalg_reduce", [NoSideEffect, SameTypeOperands]> { + let summary = "Custom reduce operation within linalg.generic"; + let description = [{ + Custom reduce operation for use within `linalg.generic`. + The actual operation is held in a block for embedding by the sparse tensor dialect. + The final value in the formula block must be a sparse_tensor.linalg_yield. + + A separate "init" block contains the starting value for the reduction. 
+
+    Example:
+    ```mlir
+    %result = sparse_tensor.linalg_reduce %a, %b : f64 to f64 {
+      ^bb0(%v0: f64, %v1: f64):
+        %ret = arith.addf %v0, %v1 : f64
+        sparse_tensor.linalg_yield %ret : f64
+    } init {
+      %init = arith.constant 0.0 : f64
+      sparse_tensor.linalg_yield %init : f64
+    }
+    ```
+  }];
+
+  let arguments = (ins AnyType:$a, AnyType:$b);
+  let results = (outs AnyType:$output);
+  let regions = (region SizedRegion<1>:$formula, SizedRegion<1>:$init);
+
+  let assemblyFormat = [{
+    $a `,` $b attr-dict `:` type($a) `to` type($output) $formula `init` $init
+  }];
+  let hasVerifier = 1;
+}
+
+def SparseTensor_LinalgApplyOp : SparseTensor_Op<"linalg_apply", [NoSideEffect]> {
+  let summary = "Custom apply operation within linalg.generic";
+  let description = [{
+    Custom apply operation for use within `linalg.generic`.
+    The actual operation is held in a block for embedding by the sparse tensor dialect.
+    The final value in the formula block must be a sparse_tensor.linalg_yield.
+
+    Example:
+    ```mlir
+    %result = sparse_tensor.linalg_apply %a : f64 to f64 {
+      ^bb0(%v0: f64):
+        %cf1 = arith.constant 1.0 : f64
+        %ret = arith.addf %v0, %cf1 : f64
+        sparse_tensor.linalg_yield %ret : f64
+    }
+    ```
+  }];
+
+  let arguments = (ins AnyType:$a);
+  let results = (outs AnyType:$output);
+  let regions = (region SizedRegion<1>:$formula);
+
+  let assemblyFormat = [{
+    $a attr-dict `:` type($a) `to` type($output) $formula
+  }];
+  let hasVerifier = 1;
+}
+
+def SparseTensor_LinalgMaskOp : SparseTensor_Op<"linalg_mask", []> {
+  let summary = "Mask operation within linalg.generic";
+  let description = [{
+    Custom mask operation for use within `linalg.generic`.
+    The mask operation must be the first operation in the linalg.generic block.
+
+    The final value in the mask block must be a sparse_tensor.linalg_yield
+    and must return a boolean (i1) result. If the returned value is false,
+    the corresponding output entry is masked out (i.e. left empty).
+    The meaning of each block argument depends on the rank of the linalg.generic output tensor.
+
+    Rank 1 (Vector)
+      - arg0 : index
+      - arg1 : value
+
+    Rank 2 (Matrix)
+      - arg0 : row
+      - arg1 : column
+      - arg2 : value
+
+    Example:
+    ```mlir
+    sparse_tensor.linalg_mask {
+      ^bb0(%row: index, %col: index):
+        %triu = arith.cmpi "ugt", %col, %row : index
+        sparse_tensor.linalg_yield %triu : i1
+    }
+    ```
+  }];
+
+  let regions = (region SizedRegion<1>:$expr);
+
+  let assemblyFormat = [{
+    $expr attr-dict
+  }];
+  let hasVerifier = 1;
+}
+
+def SparseTensor_LinalgYieldOp : SparseTensor_Op<"linalg_yield", [Terminator]> {
+  let summary = "Yield from sparse_tensor.linalg_* methods";
+  let description = [{
+    Yields a value from within a custom block.
+
+    Example:
+    ```mlir
+    sparse_tensor.linalg_yield %result : f64
+    ```
+  }];
+
+  let arguments = (ins AnyType:$result);
+  let assemblyFormat = [{
+    $result attr-dict `:` type($result)
+  }];
+}
+
 #endif // SPARSETENSOR_OPS
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 304ba93737b5b..924b0198363e5 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -44,6 +44,8 @@ enum Kind {
   kCastU, // unsigned
   kTruncI,
   kBitCast,
+  // Custom unary
+  kApply,
   // Binary operations.
   kMulF,
   kMulI,
@@ -60,6 +62,10 @@ enum Kind {
   kShrS, // signed
   kShrU, // unsigned
   kShlI,
+  // Custom binary
+  kIntersect,
+  kUnion,
+  kReduce,
 };

 /// Children subexpressions of tensor operations.
@@ -70,7 +76,7 @@ struct Children {

 /// Tensor expression.
Represents a MLIR expression in tensor index notation. struct TensorExp { - TensorExp(Kind k, unsigned x, unsigned y, Value v); + TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *op); /// Tensor expression kind. Kind kind; @@ -87,6 +93,8 @@ struct TensorExp { /// infer destination type) of a cast operation During code generation, /// this field may be used to cache "hoisted" loop invariant tensor loads. Value val; + + Operation *operation; }; /// Lattice point. Each lattice point consists of a conjunction of tensor @@ -125,9 +133,9 @@ class Merger { hasSparseOut(false), dims(t + 1, std::vector(l, Dim::kUndef)) {} /// Adds a tensor expression. Returns its index. - unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value()); - unsigned addExp(Kind k, unsigned e, Value v) { return addExp(k, e, -1u, v); } - unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); } + unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value(), Operation *op = nullptr); + unsigned addExp(Kind k, unsigned e, Value v, Operation *op = nullptr) { return addExp(k, e, -1u, v, op); } + unsigned addExp(Kind k, Value v, Operation *op = nullptr) { return addExp(k, -1u, -1u, v, op); } /// Adds an iteration lattice point. Returns its index. unsigned addLat(unsigned t, unsigned i, unsigned e); @@ -139,20 +147,20 @@ class Merger { /// of loop indices (effectively constructing a larger "intersection" of those /// indices) with a newly constructed tensor (sub)expression of given kind. /// Returns the index of the new lattice point. - unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1); + unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1, Operation *op = nullptr); /// Conjunctive merge of two lattice sets L0 and L1 is conjunction of /// cartesian product. Returns the index of the new set. - unsigned takeConj(Kind kind, unsigned s0, unsigned s1); + unsigned takeConj(Kind kind, unsigned s0, unsigned s1, Operation *op = nullptr); /// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1). /// Returns the index of the new set. - unsigned takeDisj(Kind kind, unsigned s0, unsigned s1); + unsigned takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *op = nullptr); /// Maps the unary operator over the lattice set of the operand, i.e. each /// lattice point on an expression E is simply copied over, but with OP E /// as new expression. Returns the index of the new set. - unsigned mapSet(Kind kind, unsigned s0, Value v = Value()); + unsigned mapSet(Kind kind, unsigned s0, Value v = Value(), Operation *op = nullptr); /// Optimizes the iteration lattice points in the given set. This /// method should be called right before code generation to avoid @@ -225,13 +233,18 @@ class Merger { /// Returns index of the root expression. unsigned buildLattices(unsigned e, unsigned i); + /// Returns the identity value (i.e. x op identity == x) + /// This value is used in reductions as the initial value, meant to have + /// no impact on the final reduction value. + Value getIdentity(PatternRewriter &rewriter, Location loc, unsigned e, Type tp); + /// Builds a tensor expression from the given Linalg operation. /// Returns index of the root expression on success. Optional buildTensorExpFromLinalg(linalg::GenericOp op); /// Rebuilds SSA format from a tensor expression. Value buildExp(PatternRewriter &rewriter, Location loc, unsigned e, Value v0, - Value v1); + Value v1, std::vector idxs); private: /// Private helpers. 
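Note: the TableGen descriptions above show each sparse_tensor.linalg_* op in isolation. As a point of reference, the following minimal sketch shows how such an op is intended to be embedded in the region of a `linalg.generic` so the merger can pick it up. The `#CSR` encoding, the trait, the function name, and the use of `arith.minf` are illustrative assumptions and are not taken from this patch or its tests.

```mlir
#CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>

#trait_eltwise = {
  indexing_maps = [
    affine_map<(i,j) -> (i,j)>,  // A
    affine_map<(i,j) -> (i,j)>,  // B
    affine_map<(i,j) -> (i,j)>   // X (out)
  ],
  iterator_types = ["parallel", "parallel"]
}

// Element-wise minimum over the union of the sparsity patterns of %a and %b.
func @sparse_min(%a: tensor<8x8xf64, #CSR>,
                 %b: tensor<8x8xf64, #CSR>) -> tensor<8x8xf64, #CSR> {
  %c8 = arith.constant 8 : index
  %init = sparse_tensor.init [%c8, %c8] : tensor<8x8xf64, #CSR>
  %0 = linalg.generic #trait_eltwise
    ins(%a, %b : tensor<8x8xf64, #CSR>, tensor<8x8xf64, #CSR>)
    outs(%init : tensor<8x8xf64, #CSR>) {
  ^bb0(%x: f64, %y: f64, %out: f64):
    // The union keeps entries present in either operand; the formula block
    // only describes the value computed where both operands have an entry.
    %u = sparse_tensor.linalg_union %x, %y : f64 to f64 {
      ^bb0(%v0: f64, %v1: f64):
        %min = arith.minf %v0, %v1 : f64
        sparse_tensor.linalg_yield %min : f64
    }
    linalg.yield %u : f64
  } -> tensor<8x8xf64, #CSR>
  return %0 : tensor<8x8xf64, #CSR>
}
```

This corresponds to the `takeDisj` lattice construction declared above: the conjunction of both operands uses the formula block, while the single-operand lattice points keep their original expressions.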
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index ecbc989a2c141..0038032993339 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -334,6 +334,115 @@ LogicalResult OutOp::verify() {
   return success();
 }

+//===----------------------------------------------------------------------===//
+// Sparse Tensor Custom Linalg.Generic Operations.
+//===----------------------------------------------------------------------===//
+
+LogicalResult LinalgIntersectOp::verify() {
+  Region &region = formula();
+  Block &formula = region.front();
+  if (formula.getNumArguments() != 2)
+    return emitError("block must have 2 arguments");
+
+  Type outputType = output().getType();
+  LinalgYieldOp yield =
+      llvm::dyn_cast_or_null<LinalgYieldOp>(formula.getTerminator());
+  if (yield == nullptr)
+    return emitError("intersect block must end with sparse_tensor.linalg_yield");
+  Value retVal = yield.result();
+  if (retVal.getType() != outputType)
+    return emitError("yield value in block does not match intersect return type");
+
+  return success();
+}
+
+LogicalResult LinalgUnionOp::verify() {
+  Region &region = formula();
+  Block &formula = region.front();
+  if (formula.getNumArguments() != 2)
+    return emitError("block must have 2 arguments");
+
+  Type outputType = output().getType();
+  LinalgYieldOp yield =
+      llvm::dyn_cast_or_null<LinalgYieldOp>(formula.getTerminator());
+  if (yield == nullptr)
+    return emitError("union block must end with sparse_tensor.linalg_yield");
+  Value retVal = yield.result();
+  if (retVal.getType() != outputType)
+    return emitError("yield value in block does not match union return type");
+
+  return success();
+}
+
+LogicalResult LinalgReduceOp::verify() {
+  Region &formulaRegion = formula();
+  Block &formula = formulaRegion.front();
+  if (formula.getNumArguments() != 2)
+    return emitError("formula block must have 2 arguments");
+
+  Type outputType = output().getType();
+  LinalgYieldOp yield =
+      llvm::dyn_cast_or_null<LinalgYieldOp>(formula.getTerminator());
+  if (yield == nullptr)
+    return emitError("reduce block must end with sparse_tensor.linalg_yield");
+  Value retVal = yield.result();
+  if (retVal.getType() != outputType)
+    return emitError("yield value in formula block does not match reduce return type");
+
+  Region &initRegion = init();
+  Block &init = initRegion.front();
+  if (init.getNumArguments() != 0)
+    return emitError("init block must not have any arguments");
+
+  LinalgYieldOp initYield =
+      llvm::dyn_cast_or_null<LinalgYieldOp>(init.getTerminator());
+  if (initYield == nullptr)
+    return emitError("init block must end with sparse_tensor.linalg_yield");
+  Value initVal = initYield.result();
+  if (initVal.getType() != outputType)
+    return emitError("yield value in init block does not match reduce return type");
+
+  return success();
+}
+
+LogicalResult LinalgApplyOp::verify() {
+  Region &region = formula();
+  Block &formula = region.front();
+  if (formula.getNumArguments() < 1)
+    return emitError("block must have at least 1 argument");
+  if (formula.getNumArguments() > 3)
+    return emitError("block must have no more than 3 arguments");
+
+  Type outputType = output().getType();
+  LinalgYieldOp yield =
+      llvm::dyn_cast_or_null<LinalgYieldOp>(formula.getTerminator());
+  if (yield == nullptr)
+    return emitError("apply block must end with sparse_tensor.linalg_yield");
+
+  Value retVal = yield.result();
+  if (retVal.getType() != outputType)
+    return emitError("yield value in block does not match apply return type");
+
+  return success();
+}
+
+LogicalResult LinalgMaskOp::verify() {
+  // Result of block must be i1
+  Region &region = expr();
+  Block &block = region.front();
+  LinalgYieldOp yield =
+      llvm::dyn_cast_or_null<LinalgYieldOp>(block.getTerminator());
+  if (yield == nullptr)
+    return emitError("mask block must end with sparse_tensor.linalg_yield");
+
+  Type retType = yield.result().getType();
+  IntegerType iType = retType.dyn_cast<IntegerType>();
+  if (!iType || iType.getWidth() != 1)
+    return emitError("mask block must return i1 type");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TensorDialect Methods.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 72e70ddbc123e..452531b5ec543 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -42,7 +42,7 @@ namespace {
 enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 };

 // Reduction kinds.
-enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor };
+enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor, kCustom };

 // Code generation.
 struct CodeGen {
@@ -53,10 +53,10 @@ struct CodeGen {
         indices(numTensors, std::vector<Value>(numLoops)),
         highs(numTensors, std::vector<Value>(numLoops)),
         pidxs(numTensors, std::vector<Value>(numLoops)),
-        idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
+        idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(), redValidLexInsert(),
         redKind(kNoReduc), sparseOut(op), outerParNest(nest), lexIdx(),
         expValues(), expFilled(), expAdded(), expCount(), curVecLength(1),
-        curVecMask() {}
+        curVecMask(), rank(), maskLoop(), maskYield() {}
   /// Sparsification options.
   SparsificationOptions options;
   /// Universal dense indices and upper bounds (by index). The loops array
@@ -80,6 +80,7 @@ struct CodeGen {
   /// reduction are exhausted, all inner loops can use a scalarized reduction.
   unsigned redExp;
   Value redVal;
+  Value redValidLexInsert;
   Reduction redKind;
   // Sparse tensor as output. Implemented either through direct injective
   // insertion in lexicographic index order (where indices are updated
@@ -95,6 +96,10 @@ struct CodeGen {
   // Current vector length and mask.
   unsigned curVecLength;
   Value curVecMask;
+  // Linalg mask variables
+  unsigned rank;
+  scf::IfOp maskLoop;
+  LinalgYieldOp maskYield;
 };

 } // namespace
@@ -369,6 +374,8 @@ static StringRef getReductionName(Reduction kind) {
     return "or";
   case kXor:
     return "xor";
+  case kCustom:
+    return "custom";
   }
   llvm_unreachable("unknown reduction kind");
 }
@@ -390,6 +397,8 @@ static Reduction getReduction(Kind kind) {
     return kOr;
   case Kind::kXorI:
     return kXor;
+  case Kind::kReduce:
+    return kCustom;
   default:
     llvm_unreachable("unexpected reduction operator");
   }
@@ -404,6 +413,7 @@ static Value genVectorReducInit(CodeGen &codegen, PatternRewriter &rewriter,
   Value r = codegen.redVal;
   switch (codegen.redKind) {
   case kNoReduc:
+  case kCustom:
     break;
   case kSum:
   case kXor:
@@ -434,9 +444,10 @@ static Value genVectorReducEnd(CodeGen &codegen, PatternRewriter &rewriter,
 }

 /// Updates scalarized reduction value.
-static void updateReduc(Merger &merger, CodeGen &codegen, Value reduc) { +static void updateReduc(Merger &merger, CodeGen &codegen, Value reduc, Value validLexInsert = Value()) { assert(codegen.redKind != kNoReduc); codegen.redVal = merger.exp(codegen.redExp).val = reduc; + codegen.redValidLexInsert = validLexInsert; } //===----------------------------------------------------------------------===// @@ -542,6 +553,7 @@ static void genBuffers(Merger &merger, CodeGen &codegen, auto dynShape = {ShapedType::kDynamicSize}; auto memTp = MemRefType::get(dynShape, rewriter.getIndexType()); codegen.lexIdx = rewriter.create(loc, memTp, rank); + codegen.rank = op.getRank(t); } else { // Annotated sparse tensors. auto dynShape = {ShapedType::kDynamicSize}; @@ -702,26 +714,51 @@ static Value genSubscript(CodeGen &codegen, PatternRewriter &rewriter, } /// Generates insertion code to implement dynamic tensor load. -static Value genInsertionLoad(CodeGen &codegen, PatternRewriter &rewriter, +static Value genInsertionLoad(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, linalg::GenericOp op, OpOperand *t) { Location loc = op.getLoc(); + Type tp = getElementTypeOrSelf(t->get().getType()); // Direct lexicographic index order, tensor loads as zero. if (!codegen.expValues) { - Type tp = getElementTypeOrSelf(t->get().getType()); - return constantZero(rewriter, loc, tp); + if (codegen.redKind == kNoReduc) + return constantZero(rewriter, loc, tp); + else + return merger.getIdentity(rewriter, loc, codegen.redExp, tp); } // Load from expanded access pattern. Value index = genIndex(codegen, op, t); - return rewriter.create(loc, codegen.expValues, index); + if (codegen.redKind == kNoReduc) + return rewriter.create(loc, codegen.expValues, index); + // Value may be filled; if not, use the reduction init value + Value isFilled = rewriter.create(loc, codegen.expFilled, index); + scf::IfOp if_isFilled = rewriter.create(loc, tp, isFilled, /*else=*/true); + // True branch + rewriter.setInsertionPointToStart(if_isFilled.thenBlock()); + Value valAtIndex = rewriter.create(loc, codegen.expValues, index); + rewriter.create(loc, valAtIndex); + // False branch + rewriter.setInsertionPointToStart(if_isFilled.elseBlock()); + Value initVal = merger.getIdentity(rewriter, loc, codegen.redExp, tp); + rewriter.create(loc, initVal); + rewriter.setInsertionPointAfter(if_isFilled); + // End if + return if_isFilled.getResult(0); } /// Generates insertion code to implement dynamic tensor store. static void genInsertionStore(CodeGen &codegen, PatternRewriter &rewriter, - linalg::GenericOp op, OpOperand *t, Value rhs) { + linalg::GenericOp op, OpOperand *t, Value rhs, Value validLexInsert) { Location loc = op.getLoc(); // Direct insertion in lexicographic index order. if (!codegen.expValues) { - rewriter.create(loc, t->get(), codegen.lexIdx, rhs); + if (!validLexInsert) + rewriter.create(loc, t->get(), codegen.lexIdx, rhs); + else { + scf::IfOp ifValidLexInsert = rewriter.create(loc, validLexInsert); + rewriter.setInsertionPointToStart(ifValidLexInsert.thenBlock()); + rewriter.create(loc, t->get(), codegen.lexIdx, rhs); + rewriter.setInsertionPointAfter(ifValidLexInsert); + } return; } // Generates insertion code along expanded access pattern. @@ -770,7 +807,7 @@ static Value genTensorLoad(Merger &merger, CodeGen &codegen, // Load during insertion. 
OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor]; if (t == codegen.sparseOut) - return genInsertionLoad(codegen, rewriter, op, t); + return genInsertionLoad(merger, codegen, rewriter, op, t); // Actual load. SmallVector args; Value ptr = genSubscript(codegen, rewriter, op, t, args); @@ -782,20 +819,20 @@ static Value genTensorLoad(Merger &merger, CodeGen &codegen, /// Generates a store on a dense or sparse tensor. static void genTensorStore(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, linalg::GenericOp op, - Value rhs) { + Value rhs, Value validLexInsert) { Location loc = op.getLoc(); // Test if this is a scalarized reduction. if (codegen.redVal) { if (codegen.curVecLength > 1) rhs = rewriter.create(loc, codegen.curVecMask, rhs, codegen.redVal); - updateReduc(merger, codegen, rhs); + updateReduc(merger, codegen, rhs, validLexInsert); return; } // Store during insertion. OpOperand *t = op.getOutputOperand(0); if (t == codegen.sparseOut) { - genInsertionStore(codegen, rewriter, op, t, rhs); + genInsertionStore(codegen, rewriter, op, t, rhs, validLexInsert); return; } // Actual store. @@ -874,19 +911,59 @@ static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter, return rewriter.create(loc, mul, i); } +/// Generates the mask if-statement. +static void genMaskLoop(CodeGen &codegen, PatternRewriter &rewriter, + LinalgYieldOp op, SmallVector injectArgs) { + Location loc = op.getLoc(); + Operation *placeholder = rewriter.create(loc, 0); + rewriter.mergeBlockBefore(op->getBlock(), placeholder, injectArgs); + Value maskCmp = op.result(); + rewriter.eraseOp(placeholder); + rewriter.eraseOp(op); + + if (codegen.expValues) { + TypeRange forRetType = maskCmp.getDefiningOp()->getParentOp()->getResultTypes(); + codegen.maskLoop = rewriter.create(loc, forRetType, maskCmp, true); + } else { + codegen.maskLoop = rewriter.create(loc, maskCmp); + } + rewriter.setInsertionPointToStart(codegen.maskLoop.thenBlock()); +} + /// Recursively generates tensor expression. 
static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, - linalg::GenericOp op, unsigned exp) { + linalg::GenericOp op, unsigned exp, unsigned last = 0) { Location loc = op.getLoc(); if (exp == -1u) return Value(); - if (merger.exp(exp).kind == Kind::kTensor) - return genTensorLoad(merger, codegen, rewriter, op, exp); + if (merger.exp(exp).kind == Kind::kTensor) { + OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor]; + OpOperand *lhs = op.getOutputOperand(0); + if (lhs == t) { + codegen.redKind = getReduction(merger.exp(last).kind); + codegen.redExp = last; + } + Value redVal = genTensorLoad(merger, codegen, rewriter, op, exp); + if (lhs == t) + codegen.redExp = exp; + return redVal; + } if (merger.exp(exp).kind == Kind::kInvariant) return genInvariantValue(merger, codegen, rewriter, exp); - Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0); - Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1); - return merger.buildExp(rewriter, loc, exp, v0, v1); + Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0, exp); + Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1, exp); + + if (codegen.maskYield) { + // Masking based on value must be handled here + SmallVector injectArgs; + injectArgs.push_back(codegen.loops[0]); + if (codegen.rank == 2) + injectArgs.push_back(codegen.loops[1]); + injectArgs.push_back(v0); + genMaskLoop(codegen, rewriter, codegen.maskYield, injectArgs); + } + + return merger.buildExp(rewriter, loc, exp, v0, v1, codegen.loops); } /// Determines if affine expression is invariant. @@ -914,7 +991,7 @@ static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a, static void genInvariants(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, linalg::GenericOp op, unsigned exp, unsigned ldx, bool atStart, - Kind last = Kind::kTensor) { + unsigned last = 0) { if (exp == -1u) return; if (merger.exp(exp).kind == Kind::kTensor) { @@ -935,16 +1012,19 @@ static void genInvariants(Merger &merger, CodeGen &codegen, if (lhs == t) { // Start or end a scalarized reduction if (atStart) { + codegen.redKind = getReduction(merger.exp(last).kind); + codegen.redExp = last; // this allows for custom reduction initialization Value load = genTensorLoad(merger, codegen, rewriter, op, exp); - codegen.redKind = getReduction(last); - codegen.redExp = exp; + codegen.redExp = exp; // this ties the initial reduction to the output tensor updateReduc(merger, codegen, load); + } else { Value redVal = codegen.redVal; - updateReduc(merger, codegen, Value()); + Value redValidLexInsert = codegen.redValidLexInsert; + updateReduc(merger, codegen, Value(), Value()); codegen.redExp = -1u; codegen.redKind = kNoReduc; - genTensorStore(merger, codegen, rewriter, op, redVal); + genTensorStore(merger, codegen, rewriter, op, redVal, redValidLexInsert); } } else { // Start or end loop invariant hoisting of a tensor load. @@ -955,11 +1035,10 @@ static void genInvariants(Merger &merger, CodeGen &codegen, // Traverse into the binary operations. Note that we only hoist // tensor loads, since subsequent MLIR/LLVM passes know how to // deal with all other kinds of derived loop invariants. 
- Kind last = merger.exp(exp).kind; unsigned e0 = merger.exp(exp).children.e0; unsigned e1 = merger.exp(exp).children.e1; - genInvariants(merger, codegen, rewriter, op, e0, ldx, atStart, last); - genInvariants(merger, codegen, rewriter, op, e1, ldx, atStart, last); + genInvariants(merger, codegen, rewriter, op, e0, ldx, atStart, exp); + genInvariants(merger, codegen, rewriter, op, e1, ldx, atStart, exp); } } @@ -1179,6 +1258,7 @@ static Operation *genWhile(Merger &merger, CodeGen &codegen, SmallVector types; SmallVector operands; // Construct the while-loop with a parameter for each index. + Location loc = op.getLoc(); Type indexType = rewriter.getIndexType(); for (unsigned b = 0, be = indices.size(); b < be; b++) { if (indices[b] && merger.isDim(b, Dim::kSparse)) { @@ -1191,6 +1271,12 @@ static Operation *genWhile(Merger &merger, CodeGen &codegen, if (codegen.redVal) { types.push_back(codegen.redVal.getType()); operands.push_back(codegen.redVal); + if (codegen.sparseOut) { + Type boolType = rewriter.getIntegerType(1); + Value falseval = rewriter.create(loc, 0, boolType); + types.push_back(boolType); + operands.push_back(falseval); + } } if (codegen.expValues) { types.push_back(indexType); @@ -1201,7 +1287,6 @@ static Operation *genWhile(Merger &merger, CodeGen &codegen, operands.push_back(codegen.loops[idx]); } assert(types.size() == operands.size()); - Location loc = op.getLoc(); scf::WhileOp whileOp = rewriter.create(loc, types, operands); SmallVector locs(types.size(), loc); @@ -1225,8 +1310,15 @@ static Operation *genWhile(Merger &merger, CodeGen &codegen, codegen.pidxs[tensor][idx] = after->getArgument(o++); } } - if (codegen.redVal) - updateReduc(merger, codegen, after->getArgument(o++)); + if (codegen.redVal) { + if (codegen.sparseOut) { + BlockArgument valArg = after->getArgument(o++); + BlockArgument validLexArg = after->getArgument(o++); + updateReduc(merger, codegen, valArg, validLexArg); + } else { + updateReduc(merger, codegen, after->getArgument(o++)); + } + } if (codegen.expValues) codegen.expCount = after->getArgument(o++); if (needsUniv) @@ -1311,10 +1403,39 @@ static void genLocals(Merger &merger, CodeGen &codegen, // Move the insertion indices in lexicographic index order. During access // pattern expansion, we can skip setting the innermost dimension. 
- if (codegen.sparseOut && !codegen.expValues) { - Value pos = constantIndex(rewriter, loc, at); - rewriter.create(loc, codegen.loops[idx], codegen.lexIdx, - pos); + if (codegen.sparseOut) { + if (idx == codegen.rank - 1) { + // Handle linalg_mask + Block &block = op.region().front(); + Operation &firstOp = block.front(); + LinalgMaskOp maskOp = dyn_cast_or_null(&firstOp); + if (maskOp != nullptr) { + assert(codegen.rank <= 2 && "mask only supported for tensors of rank 1 or 2"); + Region &maskRegion = maskOp.expr(); + Block &maskBlock = maskRegion.front(); + LinalgYieldOp yield = llvm::dyn_cast_or_null(maskBlock.getTerminator()); + + unsigned numArgs = maskBlock.getNumArguments(); + assert(numArgs > 0 && "mask block must have at least one argument"); + if (numArgs <= codegen.rank) { + SmallVector injectArgs; + injectArgs.push_back(codegen.loops[0]); + if (codegen.rank == 2) + injectArgs.push_back(codegen.loops[1]); + + genMaskLoop(codegen, rewriter, yield, injectArgs); + } else { + // Save for when value is available + codegen.maskYield = yield; + } + } + } + + if (!codegen.expValues) { + Value pos = constantIndex(rewriter, loc, at); + rewriter.create(loc, codegen.loops[idx], codegen.lexIdx, + pos); + } } } @@ -1333,7 +1454,13 @@ static void genWhileInduction(Merger &merger, CodeGen &codegen, SmallVector yields; if (codegen.redVal) { yields.push_back(codegen.redVal); - updateReduc(merger, codegen, ifOp.getResult(y++)); + Value valArg = ifOp.getResult(y++); + if (codegen.redValidLexInsert) { + yields.push_back(codegen.redValidLexInsert); + Value validLexArg = ifOp.getResult(y++); + updateReduc(merger, codegen, valArg, validLexArg); + } else + updateReduc(merger, codegen, valArg); } if (codegen.expValues) { yields.push_back(codegen.expCount); @@ -1369,7 +1496,14 @@ static void genWhileInduction(Merger &merger, CodeGen &codegen, } if (codegen.redVal) { operands.push_back(codegen.redVal); - updateReduc(merger, codegen, whileOp->getResult(o++)); + OpResult valArg = whileOp->getResult(o++); + if (codegen.redValidLexInsert) { + operands.push_back(codegen.redValidLexInsert); + OpResult validLexArg = whileOp->getResult(o++); + updateReduc(merger, codegen, valArg, validLexArg); + } else { + updateReduc(merger, codegen, valArg); + } } if (codegen.expValues) { operands.push_back(codegen.expCount); @@ -1394,7 +1528,13 @@ static void genForInduction(Merger &merger, CodeGen &codegen, SmallVector operands; if (codegen.redVal) { operands.push_back(codegen.redVal); - updateReduc(merger, codegen, loop->getResult(o++)); + OpResult valArg = loop->getResult(o++); + if (codegen.redValidLexInsert) { + operands.push_back(codegen.redValidLexInsert); + OpResult validLexArg = loop->getResult(o++); + updateReduc(merger, codegen, valArg, validLexArg); + } else + updateReduc(merger, codegen, valArg); } if (codegen.expValues) { operands.push_back(codegen.expCount); @@ -1403,6 +1543,19 @@ static void genForInduction(Merger &merger, CodeGen &codegen, assert(o == operands.size()); if (o > 0) rewriter.create(loc, operands); + + if (codegen.maskLoop && codegen.maskLoop->getParentOp() == loop) { + if (o > 0) { + rewriter.setInsertionPointToStart(codegen.maskLoop.elseBlock()); + scf::ForOp forLoop = dyn_cast_or_null(loop); + rewriter.create(loc, forLoop.getLoopBody().getArgument(1)); + + rewriter.setInsertionPointAfter(codegen.maskLoop); + rewriter.create(loc, codegen.maskLoop.getResults()); + } + rewriter.setInsertionPointAfter(codegen.maskLoop); + } + rewriter.setInsertionPointAfter(loop); } @@ -1429,8 +1582,11 @@ static 
scf::IfOp genIf(Merger &merger, CodeGen &codegen, cond = cond ? rewriter.create(loc, cond, clause) : clause; } } - if (codegen.redVal) + if (codegen.redVal) { types.push_back(codegen.redVal.getType()); + if (codegen.sparseOut) + types.push_back(codegen.redValidLexInsert.getType()); + } if (codegen.expValues) types.push_back(rewriter.getIndexType()); scf::IfOp ifOp = rewriter.create(loc, types, cond, /*else=*/true); @@ -1441,11 +1597,13 @@ static scf::IfOp genIf(Merger &merger, CodeGen &codegen, /// Generates end of true branch of if-statement within a while-loop. static void endIf(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, linalg::GenericOp op, scf::IfOp ifOp, Operation *loop, - Value redInput, Value cntInput) { + Value redInput, Value redValidLexInsert, Value cntInput) { SmallVector operands; if (codegen.redVal) { operands.push_back(codegen.redVal); - updateReduc(merger, codegen, redInput); + if (codegen.redValidLexInsert) + operands.push_back(codegen.redValidLexInsert); + updateReduc(merger, codegen, redInput, redValidLexInsert); } if (codegen.expValues) { operands.push_back(codegen.expCount); @@ -1547,7 +1705,15 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, // At each leaf, assign remaining tensor (sub)expression to output tensor. if (at == topSort.size()) { Value rhs = genExp(merger, codegen, rewriter, op, exp); - genTensorStore(merger, codegen, rewriter, op, rhs); + Value validLexInsert; + if (codegen.redValidLexInsert) { + Location loc = op.getLoc(); + Type boolType = rewriter.getIntegerType(1); + validLexInsert = rewriter.create(loc, 1, boolType); + } else { + validLexInsert = Value(); + } + genTensorStore(merger, codegen, rewriter, op, rhs, validLexInsert); return; } @@ -1571,6 +1737,7 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, // Visit all lattices points with Li >= Lj to generate the // loop-body, possibly with if statements for coiteration. Value redInput = codegen.redVal; + Value redValidLexInsert = codegen.redValidLexInsert; Value cntInput = codegen.expCount; bool isWhile = dyn_cast(loop) != nullptr; for (unsigned j = 0; j < lsize; j++) { @@ -1582,7 +1749,7 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, scf::IfOp ifOp = genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple); genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1); - endIf(merger, codegen, rewriter, op, ifOp, loop, redInput, cntInput); + endIf(merger, codegen, rewriter, op, ifOp, loop, redInput, redValidLexInsert, cntInput); } else { genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1); } diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp index 31e7fb5a07edd..0bbc6d3c1c645 100644 --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/SparseTensor/Utils/Merger.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/IR/Operation.h" #include "llvm/Support/Debug.h" @@ -19,8 +20,8 @@ namespace sparse_tensor { // Constructors. 
//===----------------------------------------------------------------------===// -TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v) - : kind(k), val(v) { +TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *op) + : kind(k), val(v), operation(op) { switch (kind) { case kTensor: assert(x != -1u && y == -1u && !v); @@ -52,6 +53,18 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v) children.e0 = x; children.e1 = y; break; + case kApply: + assert(x != -1u && y == -1u && op); + children.e0 = x; + children.e1 = y; + break; + case kIntersect: + case kUnion: + case kReduce: + assert(x != -1u && y != -1u && op); + children.e0 = x; + children.e1 = y; + break; default: assert(x != -1u && y != -1u && !v); children.e0 = x; @@ -72,9 +85,9 @@ LatPoint::LatPoint(const BitVector &b, unsigned e) // Lattice methods. //===----------------------------------------------------------------------===// -unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) { +unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v, Operation *op) { unsigned e = tensorExps.size(); - tensorExps.push_back(TensorExp(k, e0, e1, v)); + tensorExps.push_back(TensorExp(k, e0, e1, v, op)); return e; } @@ -91,25 +104,25 @@ unsigned Merger::addSet() { return s; } -unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) { +unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1, Operation *op) { unsigned p = latPoints.size(); BitVector nb = BitVector(latPoints[p0].bits); nb |= latPoints[p1].bits; - unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp); + unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp, Value(), op); latPoints.push_back(LatPoint(nb, e)); return p; } -unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) { +unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1, Operation *op) { unsigned s = addSet(); for (unsigned p0 : latSets[s0]) for (unsigned p1 : latSets[s1]) - latSets[s].push_back(conjLatPoint(kind, p0, p1)); + latSets[s].push_back(conjLatPoint(kind, p0, p1, op)); return s; } -unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) { - unsigned s = takeConj(kind, s0, s1); +unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *op) { + unsigned s = takeConj(kind, s0, s1, op); // Followed by all in s0. 
for (unsigned p : latSets[s0]) latSets[s].push_back(p); @@ -124,11 +137,11 @@ unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) { return s; } -unsigned Merger::mapSet(Kind kind, unsigned s0, Value v) { - assert(kAbsF <= kind && kind <= kBitCast); +unsigned Merger::mapSet(Kind kind, unsigned s0, Value v, Operation *op) { + assert(kAbsF <= kind && kind <= kApply); unsigned s = addSet(); for (unsigned p : latSets[s0]) { - unsigned e = addExp(kind, latPoints[p].exp, v); + unsigned e = addExp(kind, latPoints[p].exp, v, op); latPoints.push_back(LatPoint(latPoints[p].bits, e)); latSets[s].push_back(latPoints.size() - 1); } @@ -324,6 +337,14 @@ static const char *kindToOpSymbol(Kind kind) { return ">>"; case kShlI: return "<<"; + case kApply: + return "CustomLinalgApply"; + case kIntersect: + return "CustomLinalgIntersect"; + case kUnion: + return "CustomLinalgUnion"; + case kReduce: + return "CustomLinalgReduce"; } llvm_unreachable("unexpected kind for symbol"); } @@ -515,10 +536,62 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) { return takeConj(kind, // take binary conjunction buildLattices(tensorExps[e].children.e0, i), buildLattices(tensorExps[e].children.e1, i)); + case kApply: + return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), + tensorExps[e].val, tensorExps[e].operation); + case kIntersect: + // Custom conjunction + return takeConj(kind, // take binary conjunction + buildLattices(tensorExps[e].children.e0, i), + buildLattices(tensorExps[e].children.e1, i), + tensorExps[e].operation); + case kUnion: + // Custom disjunction + return takeDisj(kind, // take binary disjunction + buildLattices(tensorExps[e].children.e0, i), + buildLattices(tensorExps[e].children.e1, i), + tensorExps[e].operation); + case kReduce: + return takeConj(kind, + buildLattices(tensorExps[e].children.e0, i), + buildLattices(tensorExps[e].children.e1, i), + tensorExps[e].operation); } llvm_unreachable("unexpected expression kind"); } +Value Merger::getIdentity(PatternRewriter &rewriter, Location loc, unsigned e, Type tp) { + Kind kind = tensorExps[e].kind; + switch (kind) { + case kAddF: + case kAddI: + case kOrI: + return rewriter.create(loc, tp, rewriter.getZeroAttr(tp)); + case kMulF: + return rewriter.create(loc, tp, rewriter.getFloatAttr(tp, 1.0)); + case kMulI: + case kAndI: + return rewriter.create(loc, tp, rewriter.getIntegerAttr(tp, 1)); + case kReduce: + { + // Insert identity from init block + Operation *origOp = tensorExps[e].operation; + sparse_tensor::LinalgReduceOp laop = llvm::dyn_cast_or_null(origOp); + Region ®ion = laop.init(); + Block &formula = region.front(); + LinalgYieldOp yield = dyn_cast_or_null(formula.getTerminator()); + Operation *placeholder = rewriter.create(loc, 0); + rewriter.mergeBlockBefore(&formula, placeholder, {}); + Value retVal = yield.result(); + rewriter.eraseOp(placeholder); + rewriter.eraseOp(yield); + return retVal; + } + default: + return rewriter.create(loc, tp, rewriter.getZeroAttr(tp)); + } +} + Optional Merger::buildTensorExpFromLinalg(linalg::GenericOp op) { Operation *yield = op.region().front().getTerminator(); return buildTensorExp(op, yield->getOperand(0)); @@ -602,6 +675,13 @@ Optional Merger::buildTensorExp(linalg::GenericOp op, Value v) { return addExp(kTruncI, e, v); if (isa(def)) return addExp(kBitCast, e, v); + if (isa(def)) { + sparse_tensor::LinalgApplyOp laop = v.getDefiningOp(); + Region ®ion = laop.formula(); + Block &formula = region.front(); + Operation &term = formula.back(); + return addExp(kApply, e, Value(), 
&term); + } } } // Construct binary operations if subexpressions can be built. @@ -643,6 +723,22 @@ Optional Merger::buildTensorExp(linalg::GenericOp op, Value v) { return addExp(kShrU, e0, e1); if (isa(def) && isInvariant(e1)) return addExp(kShlI, e0, e1); + if (isa(def)) { + sparse_tensor::LinalgIntersectOp laop = v.getDefiningOp(); + Region ®ion = laop.formula(); + Block &formula = region.front(); + Operation &term = formula.back(); + return addExp(kIntersect, e0, e1, Value(), &term); + } + if (isa(def)) { + sparse_tensor::LinalgUnionOp laop = v.getDefiningOp(); + Region ®ion = laop.formula(); + Block &formula = region.front(); + Operation &term = formula.back(); + return addExp(kUnion, e0, e1, Value(), &term); + } + if (isa(def)) + return addExp(kReduce, e0, e1, Value(), def); } } // Cannot build. @@ -650,7 +746,7 @@ Optional Merger::buildTensorExp(linalg::GenericOp op, Value v) { } Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e, - Value v0, Value v1) { + Value v0, Value v1, std::vector idxs) { switch (tensorExps[e].kind) { case kTensor: case kInvariant: @@ -721,6 +817,51 @@ Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e, return rewriter.create(loc, v0, v1); case kShlI: return rewriter.create(loc, v0, v1); + case kApply: + { + LinalgYieldOp yield = dyn_cast_or_null(tensorExps[e].operation); + Operation *placeholder = rewriter.create(loc, 0); + Block *block = yield->getBlock(); + SmallVector injectArgs; + injectArgs.push_back(v0); + if (block->getNumArguments() >= 2) + injectArgs.push_back(idxs[0]); + if (block->getNumArguments() >= 3) + injectArgs.push_back(idxs[1]); + rewriter.mergeBlockBefore(block, placeholder, injectArgs); + Value retVal = yield.result(); + rewriter.eraseOp(placeholder); + rewriter.eraseOp(yield); + return retVal; + } + case kIntersect: + case kUnion: + { + // Merge the formula block into the loop + LinalgYieldOp yield = dyn_cast_or_null(tensorExps[e].operation); + Operation *placeholder = rewriter.create(loc, 0); + rewriter.mergeBlockBefore(yield->getBlock(), placeholder, {v0, v1}); + Value retVal = yield.result(); + rewriter.eraseOp(placeholder); + rewriter.eraseOp(yield); + return retVal; + } + case kReduce: + { + // Merge the formula block into the loop + Operation *origOp = tensorExps[e].operation; + sparse_tensor::LinalgReduceOp laop = llvm::dyn_cast_or_null(origOp); + Region ®ion = laop.formula(); + Block &formula = region.front(); + LinalgYieldOp yield = dyn_cast_or_null(formula.getTerminator()); + Operation *placeholder = rewriter.create(loc, 0); + rewriter.mergeBlockBefore(&formula, placeholder, {v0, v1}); + Value retVal = yield.result(); + rewriter.eraseOp(placeholder); + rewriter.eraseOp(yield); + return retVal; + } + } llvm_unreachable("unexpected expression kind in build"); } diff --git a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir index 7d8461ce2e167..42b0e137b2296 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir @@ -62,6 +62,7 @@ func @matmul1(%a: tensor<10x20xf32, #DCSR>, // CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index // CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index // CHECK-DAG: %[[VAL_5:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[VAL_199:.*]] = arith.constant 0.000000e+00 : f64 // CHECK-DAG: %[[VAL_6:.*]] = arith.constant false // CHECK-DAG: %[[VAL_7:.*]] = arith.constant true // CHECK: %[[VAL_8:.*]] = sparse_tensor.init{{\[}}%[[VAL_2]], 
%[[VAL_2]]] : tensor<4x4xf64, #sparse_tensor.encoding<{{{.*}}}>> @@ -108,10 +109,16 @@ func @matmul1(%a: tensor<10x20xf32, #DCSR>, // CHECK: %[[VAL_54:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_53]]] : memref // CHECK: %[[VAL_55:.*]] = scf.for %[[VAL_56:.*]] = %[[VAL_52]] to %[[VAL_54]] step %[[VAL_4]] iter_args(%[[VAL_57:.*]] = %[[VAL_42]]) -> (index) { // CHECK: %[[VAL_58:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_56]]] : memref -// CHECK: %[[VAL_59:.*]] = memref.load %[[VAL_24]]{{\[}}%[[VAL_58]]] : memref +// CHECK: %[[VAL_200:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_58]]] : memref +// CHECK: %[[VAL_201:.*]] = scf.if %[[VAL_200]] -> (f64) { +// CHECK: %[[VAL_202:.*]] = memref.load %[[VAL_24]]{{\[}}%[[VAL_58]]] : memref +// CHECK: scf.yield %[[VAL_202]] : f64 +// CHECK: } else { +// CHECK: scf.yield %[[VAL_199]] : f64 +// CHECK: } // CHECK: %[[VAL_60:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_56]]] : memref // CHECK: %[[VAL_61:.*]] = arith.mulf %[[VAL_51]], %[[VAL_60]] : f64 -// CHECK: %[[VAL_62:.*]] = arith.addf %[[VAL_59]], %[[VAL_61]] : f64 +// CHECK: %[[VAL_62:.*]] = arith.addf %[[VAL_201]], %[[VAL_61]] : f64 // CHECK: %[[VAL_63:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_58]]] : memref // CHECK: %[[VAL_64:.*]] = arith.cmpi eq, %[[VAL_63]], %[[VAL_6]] : i1 // CHECK: %[[VAL_65:.*]] = scf.if %[[VAL_64]] -> (index) { diff --git a/mlir/test/Dialect/SparseTensor/sparse_out.mlir b/mlir/test/Dialect/SparseTensor/sparse_out.mlir index 13a984b8e8d25..f818409994ed3 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_out.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_out.mlir @@ -160,6 +160,8 @@ func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20xf32, // CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index // CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2 : index // CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[VAL_198:.*]] = arith.constant false +// CHECK-DAG: %[[VAL_199:.*]] = arith.constant true // CHECK: %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor> // CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor> // CHECK: %[[VAL_8:.*]] = sparse_tensor.init{{\[}}%[[VAL_6]], %[[VAL_7]]] : tensor> @@ -226,13 +228,13 @@ func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20xf32, // CHECK: %[[VAL_67:.*]] = memref.load %[[VAL_20]]{{\[}}%[[VAL_56]]] : memref // CHECK: %[[VAL_68:.*]] = arith.addi %[[VAL_56]], %[[VAL_3]] : index // CHECK: %[[VAL_69:.*]] = memref.load %[[VAL_20]]{{\[}}%[[VAL_68]]] : memref -// CHECK: %[[VAL_70:.*]]:3 = scf.while (%[[VAL_71:.*]] = %[[VAL_64]], %[[VAL_72:.*]] = %[[VAL_67]], %[[VAL_73:.*]] = %[[VAL_5]]) : (index, index, i32) -> (index, index, i32) { +// CHECK: %[[VAL_70:.*]]:4 = scf.while (%[[VAL_71:.*]] = %[[VAL_64]], %[[VAL_72:.*]] = %[[VAL_67]], %[[VAL_73:.*]] = %[[VAL_5]], %[[VAL_200:.*]] = %[[VAL_198]]) : (index, index, i32, i1) -> (index, index, i32, i1) { // CHECK: %[[VAL_74:.*]] = arith.cmpi ult, %[[VAL_71]], %[[VAL_66]] : index // CHECK: %[[VAL_75:.*]] = arith.cmpi ult, %[[VAL_72]], %[[VAL_69]] : index // CHECK: %[[VAL_76:.*]] = arith.andi %[[VAL_74]], %[[VAL_75]] : i1 -// CHECK: scf.condition(%[[VAL_76]]) %[[VAL_71]], %[[VAL_72]], %[[VAL_73]] : index, index, i32 +// CHECK: scf.condition(%[[VAL_76]]) %[[VAL_71]], %[[VAL_72]], %[[VAL_73]], %[[VAL_200]] : index, index, i32, i1 // CHECK: } do { -// CHECK: ^bb0(%[[VAL_77:.*]]: index, %[[VAL_78:.*]]: index, %[[VAL_79:.*]]: i32): +// CHECK: ^bb0(%[[VAL_77:.*]]: index, %[[VAL_78:.*]]: index, %[[VAL_79:.*]]: i32, %[[VAL_201:.*]]: i1): // 
CHECK: %[[VAL_80:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_77]]] : memref // CHECK: %[[VAL_81:.*]] = memref.load %[[VAL_21]]{{\[}}%[[VAL_78]]] : memref // CHECK: %[[VAL_82:.*]] = arith.cmpi ult, %[[VAL_81]], %[[VAL_80]] : index @@ -241,14 +243,14 @@ func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20xf32, // CHECK: %[[VAL_84:.*]] = arith.cmpi eq, %[[VAL_80]], %[[VAL_83]] : index // CHECK: %[[VAL_85:.*]] = arith.cmpi eq, %[[VAL_81]], %[[VAL_83]] : index // CHECK: %[[VAL_86:.*]] = arith.andi %[[VAL_84]], %[[VAL_85]] : i1 -// CHECK: %[[VAL_87:.*]] = scf.if %[[VAL_86]] -> (i32) { +// CHECK: %[[VAL_87:.*]]:2 = scf.if %[[VAL_86]] -> (i32, i1) { // CHECK: %[[VAL_88:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_77]]] : memref // CHECK: %[[VAL_89:.*]] = memref.load %[[VAL_22]]{{\[}}%[[VAL_78]]] : memref // CHECK: %[[VAL_90:.*]] = arith.muli %[[VAL_88]], %[[VAL_89]] : i32 // CHECK: %[[VAL_91:.*]] = arith.addi %[[VAL_79]], %[[VAL_90]] : i32 -// CHECK: scf.yield %[[VAL_91]] : i32 +// CHECK: scf.yield %[[VAL_91]], %[[VAL_199]] : i32, i1 // CHECK: } else { -// CHECK: scf.yield %[[VAL_79]] : i32 +// CHECK: scf.yield %[[VAL_79]], %[[VAL_201]] : i32, i1 // CHECK: } // CHECK: %[[VAL_92:.*]] = arith.cmpi eq, %[[VAL_80]], %[[VAL_83]] : index // CHECK: %[[VAL_93:.*]] = arith.addi %[[VAL_77]], %[[VAL_3]] : index @@ -256,9 +258,11 @@ func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20xf32, // CHECK: %[[VAL_95:.*]] = arith.cmpi eq, %[[VAL_81]], %[[VAL_83]] : index // CHECK: %[[VAL_96:.*]] = arith.addi %[[VAL_78]], %[[VAL_3]] : index // CHECK: %[[VAL_97:.*]] = arith.select %[[VAL_95]], %[[VAL_96]], %[[VAL_78]] : index -// CHECK: scf.yield %[[VAL_94]], %[[VAL_97]], %[[VAL_98:.*]] : index, index, i32 +// CHECK: scf.yield %[[VAL_94]], %[[VAL_97]], %[[VAL_87]]#0, %[[VAL_87]]#1 : index, index, i32, i1 +// CHECK: } +// CHECK: scf.if %[[VAL_70]]#3 { +// CHECK: sparse_tensor.lex_insert %[[VAL_8]], %[[VAL_23]], %[[VAL_99:.*]]#2 : tensor, memref, i32 // CHECK: } -// CHECK: sparse_tensor.lex_insert %[[VAL_8]], %[[VAL_23]], %[[VAL_99:.*]]#2 : tensor, memref, i32 // CHECK: } else { // CHECK: } // CHECK: %[[VAL_100:.*]] = arith.cmpi eq, %[[VAL_57]], %[[VAL_60]] : index @@ -317,6 +321,7 @@ func @sumred(%arga: tensor, // CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index // CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index // CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[VAL_199:.*]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: %[[VAL_5:.*]] = arith.constant false // CHECK-DAG: %[[VAL_6:.*]] = arith.constant true // CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor> @@ -365,10 +370,16 @@ func @sumred(%arga: tensor, // CHECK: %[[VAL_55:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_54]]] : memref // CHECK: %[[VAL_56:.*]] = scf.for %[[VAL_57:.*]] = %[[VAL_53]] to %[[VAL_55]] step %[[VAL_3]] iter_args(%[[VAL_58:.*]] = %[[VAL_43]]) -> (index) { // CHECK: %[[VAL_59:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_57]]] : memref -// CHECK: %[[VAL_60:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_59]]] : memref +// CHECK: %[[VAL_200:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_59]]] : memref +// CHECK: %[[VAL_201:.*]] = scf.if %[[VAL_200]] -> (f32) { +// CHECK: %[[VAL_202:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_59]]] : memref +// CHECK: scf.yield %[[VAL_202]] : f32 +// CHECK: } else { +// CHECK: scf.yield %[[VAL_199]] : f32 +// CHECK: } // CHECK: %[[VAL_61:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_57]]] : memref // CHECK: %[[VAL_62:.*]] = arith.mulf 
%[[VAL_52]], %[[VAL_61]] : f32 -// CHECK: %[[VAL_63:.*]] = arith.addf %[[VAL_60]], %[[VAL_62]] : f32 +// CHECK: %[[VAL_63:.*]] = arith.addf %[[VAL_201]], %[[VAL_62]] : f32 // CHECK: %[[VAL_64:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_59]]] : memref // CHECK: %[[VAL_65:.*]] = arith.cmpi eq, %[[VAL_64]], %[[VAL_5]] : i1 // CHECK: %[[VAL_66:.*]] = scf.if %[[VAL_65]] -> (index) {
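For completeness, here is a sketch (not part of the patch's test files, and using an assumed `#CSR` encoding and function name) of the kind of kernel the mask and custom-reduction code paths above are intended to lower: sparse_tensor.linalg_mask gates lexicographic insertion through the new redValidLexInsert flag, sparse_tensor.linalg_intersect supplies the multiplicative formula, and the init block of sparse_tensor.linalg_reduce provides the identity value materialized by Merger::getIdentity.

```mlir
#CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>

#trait_matmul = {
  indexing_maps = [
    affine_map<(i,j,k) -> (i,k)>,  // A
    affine_map<(i,j,k) -> (k,j)>,  // B
    affine_map<(i,j,k) -> (i,j)>   // X (out)
  ],
  iterator_types = ["parallel", "parallel", "reduction"]
}

// Upper-triangular masked matmul: X(i,j) = sum_k A(i,k) * B(k,j), kept only for j > i.
func @masked_matmul(%a: tensor<4x4xf64, #CSR>,
                    %b: tensor<4x4xf64, #CSR>) -> tensor<4x4xf64, #CSR> {
  %c4 = arith.constant 4 : index
  %init = sparse_tensor.init [%c4, %c4] : tensor<4x4xf64, #CSR>
  %0 = linalg.generic #trait_matmul
    ins(%a, %b : tensor<4x4xf64, #CSR>, tensor<4x4xf64, #CSR>)
    outs(%init : tensor<4x4xf64, #CSR>) {
  ^bb0(%va: f64, %vb: f64, %acc: f64):
    // The mask must be the first operation in the block; it only keeps
    // entries strictly above the diagonal.
    sparse_tensor.linalg_mask {
      ^bb0(%row: index, %col: index):
        %triu = arith.cmpi ugt, %col, %row : index
        sparse_tensor.linalg_yield %triu : i1
    }
    // Multiply only where both operands have a stored entry.
    %prod = sparse_tensor.linalg_intersect %va, %vb : f64 to f64 {
      ^bb0(%v0: f64, %v1: f64):
        %m = arith.mulf %v0, %v1 : f64
        sparse_tensor.linalg_yield %m : f64
    }
    // Custom reduction into the output; the init block yields the identity.
    %sum = sparse_tensor.linalg_reduce %prod, %acc : f64 to f64 {
      ^bb0(%v0: f64, %v1: f64):
        %s = arith.addf %v0, %v1 : f64
        sparse_tensor.linalg_yield %s : f64
    } init {
      %id = arith.constant 0.0 : f64
      sparse_tensor.linalg_yield %id : f64
    }
    linalg.yield %sum : f64
  } -> tensor<4x4xf64, #CSR>
  return %0 : tensor<4x4xf64, #CSR>
}
```

Because this mask only inspects the loop indices, genLocals can emit its scf.if before the innermost loop body; a mask whose block also takes the value argument is saved in codegen.maskYield and materialized later in genExp once the value is available.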