Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 175 additions & 0 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -382,4 +382,179 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// Sparse Tensor Custom Linalg.Generic Operations.
//===----------------------------------------------------------------------===//

// Binary "intersect" combinator: the body region supplies the value produced
// where both operands are present. SameTypeOperands forces $a and $b to share
// one type; the result type may differ (e.g. f64 operands yielding i8 below).
def SparseTensor_LinalgIntersectOp : SparseTensor_Op<"linalg_intersect", [NoSideEffect, SameTypeOperands]> {
let summary = "Custom intersect operation within linalg.generic";
let description = [{
Custom intersect operation for use within `linalg.generic`.
The actual operation is held in a block for embedding by the sparse tensor dialect.
The final value in the block must be a sparse_tensor.linalg_yield.

Example:
```mlir
%result = sparse_tensor.linalg_intersect %a, %b : f64 to i8 {
^bb0(%v0: f64, %v1: f64):
%cmp = arith.cmpf "oeq", %v0, %v1 : f64
%ret_i8 = arith.extui %cmp : i1 to i8
sparse_tensor.linalg_yield %ret_i8 : i8
}
```
}];

// Two same-typed scalar operands; one scalar result of possibly distinct type.
let arguments = (ins AnyType:$a, AnyType:$b);
let results = (outs AnyType:$output);
// Exactly one block holding the user-supplied formula.
let regions = (region SizedRegion<1>:$formula);

let assemblyFormat = [{
$a `,` $b attr-dict `:` type($a) `to` type($output) $formula
}];
// Structural checks (block arity, terminator, yield type) live in
// LinalgIntersectOp::verify() in SparseTensorDialect.cpp.
let hasVerifier = 1;
}

// Binary "union" combinator: the body region supplies the value produced
// where either operand is present. SameTypeOperands forces $a and $b to share
// one type; the result type is declared separately via $output.
def SparseTensor_LinalgUnionOp : SparseTensor_Op<"linalg_union", [NoSideEffect, SameTypeOperands]> {
let summary = "Custom union operation within linalg.generic";
let description = [{
Custom union operation for use within `linalg.generic`.
The actual operation is held in a block for embedding by the sparse tensor dialect.
The final value in the block must be a sparse_tensor.linalg_yield.

Example:
```mlir
%result = sparse_tensor.linalg_union %a, %b: f64 to f64 {
^bb0(%v0: f64, %v1: f64):
%cmp = arith.cmpf "olt", %v0, %v1 : f64
%smaller = select %cmp, %v0, %v1 : f64
sparse_tensor.linalg_yield %smaller : f64
}
```
}];

// Two same-typed scalar operands; one scalar result.
let arguments = (ins AnyType:$a, AnyType:$b);
let results = (outs AnyType:$output);
// Exactly one block holding the user-supplied formula.
let regions = (region SizedRegion<1>:$formula);

let assemblyFormat = [{
$a `,` $b attr-dict `:` type($a) `to` type($output) $formula
}];
// Structural checks live in LinalgUnionOp::verify().
let hasVerifier = 1;
}

// Custom reduction: $formula folds two values into one, $init yields the
// starting value of the reduction. Both regions terminate in linalg_yield.
def SparseTensor_LinalgReduceOp : SparseTensor_Op<"linalg_reduce", [NoSideEffect, SameTypeOperands]> {
let summary = "Custom reduce operation within linalg.generic";
let description = [{
Custom reduce operation for use within `linalg.generic`.
The actual operation is held in a block for embedding by the sparse tensor dialect.
The final value in the formula block must be a sparse_tensor.linalg_yield.

A separate "init" block contains the starting value for the reduction.

Example:
```mlir
%result = sparse_tensor.linalg_reduce %a, %b : f64 to f64 {
^bb0(%v0: f64, %v1: f64):
%ret = arith.addf %v0, %v1 : f64
sparse_tensor.linalg_yield %ret : f64
} init {
%init = arith.constant 0.0 : f64
sparse_tensor.linalg_yield %init : f64
}
```
}];

// Two same-typed operands (accumulator and incoming value); one result.
let arguments = (ins AnyType:$a, AnyType:$b);
let results = (outs AnyType:$output);
// $formula: the two-argument combining function.
// $init: zero-argument block yielding the reduction's initial value.
let regions = (region SizedRegion<1>:$formula, SizedRegion<1>:$init);

let assemblyFormat = [{
$a `,` $b attr-dict `:` type($a) `to` type($output) $formula `init` $init
}];
// Structural checks (arity, terminators, yield types) live in
// LinalgReduceOp::verify().
let hasVerifier = 1;
}

// Element-wise "apply" combinator: the body region maps the operand value to
// the result value. Note the op declares a single operand $a, while the C++
// verifier accepts 1-3 block arguments — presumably to allow index arguments
// when embedded; confirm against the embedding logic.
def SparseTensor_LinalgApplyOp : SparseTensor_Op<"linalg_apply", [NoSideEffect]> {
let summary = "Custom apply operation within linalg.generic";
let description = [{
Custom apply operation for use within `linalg.generic`.
The actual operation is held in a block for embedding by the sparse tensor dialect.
The final value in the formula block must be a sparse_tensor.linalg_yield.

Example:
```
%result = sparse_tensor.linalg_apply %a : f64 to f64 {
^bb0(%v0: f64):
%cf1 = arith.constant 1.0 : f64
%ret = arith.addf %v0, %cf1 : f64
sparse_tensor.linalg_yield %ret : f64
}
```
}];

// One scalar operand; one scalar result of possibly distinct type.
let arguments = (ins AnyType:$a);
let results = (outs AnyType:$output);
// Exactly one block holding the user-supplied formula.
let regions = (region SizedRegion<1>:$formula);

let assemblyFormat = [{
$a attr-dict `:` type($a) `to` type($output) $formula
}];
// Structural checks live in LinalgApplyOp::verify().
let hasVerifier = 1;
}

// Masking combinator: the $expr region decides, per element, whether the
// output at that position is kept (yield true) or dropped (yield false).
// The yielded value must be i1; this is enforced by LinalgMaskOp::verify().
def SparseTensor_LinalgMaskOp : SparseTensor_Op<"linalg_mask", []> {
  let summary = "Mask operation within linalg.generic";
  let description = [{
      Custom mask operation for use within `linalg.generic`.
      The mask operation must be the first operation in the linalg.generic block.

      The final value in the mask block must be a sparse_tensor.linalg_yield
      and must return a boolean (i1) result. If the return value is false,
      the output will be masked (i.e. empty).
      The meaning of each block argument depends on the rank of the linalg.generic output tensor.

      Rank 1 (Vector)
        - arg0 : index
        - arg1 : value

      Rank 2 (Matrix)
        - arg0 : row
        - arg1 : column
        - arg2 : value

      Example:
      ```
      sparse_tensor.linalg_mask {
        ^bb0(%row: index, %col: index):
          %triu = arith.cmpi "ugt", %col, %row : index
          sparse_tensor.linalg_yield %triu : i1
      }
      ```
  }];

  // Single block computing the i1 keep/drop decision.
  let regions = (region SizedRegion<1>:$expr);

  let assemblyFormat = [{
    $expr attr-dict
  }];
  let hasVerifier = 1;
}

// Common terminator for all sparse_tensor.linalg_* regions above: returns the
// block's computed value to the enclosing custom op.
def SparseTensor_LinalgYieldOp : SparseTensor_Op<"linalg_yield", [Terminator]> {
let summary = "Yield from sparse_tensor.linalg_* methods";
let description = [{
Yield a value from within a custom block

Example:
```
sparse_tensor.linalg_yield %result : f64
```
}];

// The single value handed back to the parent op.
let arguments = (ins AnyType:$result);
let assemblyFormat = [{
$result attr-dict `:` type($result)
}];
}

#endif // SPARSETENSOR_OPS
31 changes: 22 additions & 9 deletions mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ enum Kind {
kCastU, // unsigned
kTruncI,
kBitCast,
// Custom unary
kApply,
// Binary operations.
kMulF,
kMulI,
Expand All @@ -60,6 +62,10 @@ enum Kind {
kShrS, // signed
kShrU, // unsigned
kShlI,
// Custom binary
kIntersect,
kUnion,
kReduce,
};

/// Children subexpressions of tensor operations.
Expand All @@ -70,7 +76,7 @@ struct Children {

/// Tensor expression. Represents a MLIR expression in tensor index notation.
struct TensorExp {
TensorExp(Kind k, unsigned x, unsigned y, Value v);
TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *op);

/// Tensor expression kind.
Kind kind;
Expand All @@ -87,6 +93,8 @@ struct TensorExp {
/// infer destination type) of a cast operation During code generation,
/// this field may be used to cache "hoisted" loop invariant tensor loads.
Value val;

Operation *operation;
};

/// Lattice point. Each lattice point consists of a conjunction of tensor
Expand Down Expand Up @@ -125,9 +133,9 @@ class Merger {
hasSparseOut(false), dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}

/// Adds a tensor expression. Returns its index.
unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value());
unsigned addExp(Kind k, unsigned e, Value v) { return addExp(k, e, -1u, v); }
unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); }
unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value(), Operation *op = nullptr);
unsigned addExp(Kind k, unsigned e, Value v, Operation *op = nullptr) { return addExp(k, e, -1u, v, op); }
unsigned addExp(Kind k, Value v, Operation *op = nullptr) { return addExp(k, -1u, -1u, v, op); }

/// Adds an iteration lattice point. Returns its index.
unsigned addLat(unsigned t, unsigned i, unsigned e);
Expand All @@ -139,20 +147,20 @@ class Merger {
/// of loop indices (effectively constructing a larger "intersection" of those
/// indices) with a newly constructed tensor (sub)expression of given kind.
/// Returns the index of the new lattice point.
unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1);
unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1, Operation *op = nullptr);

/// Conjunctive merge of two lattice sets L0 and L1 is conjunction of
/// cartesian product. Returns the index of the new set.
unsigned takeConj(Kind kind, unsigned s0, unsigned s1);
unsigned takeConj(Kind kind, unsigned s0, unsigned s1, Operation *op = nullptr);

/// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1).
/// Returns the index of the new set.
unsigned takeDisj(Kind kind, unsigned s0, unsigned s1);
unsigned takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *op = nullptr);

/// Maps the unary operator over the lattice set of the operand, i.e. each
/// lattice point on an expression E is simply copied over, but with OP E
/// as new expression. Returns the index of the new set.
unsigned mapSet(Kind kind, unsigned s0, Value v = Value());
unsigned mapSet(Kind kind, unsigned s0, Value v = Value(), Operation *op = nullptr);

/// Optimizes the iteration lattice points in the given set. This
/// method should be called right before code generation to avoid
Expand Down Expand Up @@ -225,13 +233,18 @@ class Merger {
/// Returns index of the root expression.
unsigned buildLattices(unsigned e, unsigned i);

/// Returns the identity value (i.e. x op identity == x)
/// This value is used in reductions as the initial value, meant to have
/// no impact on the final reduction value.
Value getIdentity(PatternRewriter &rewriter, Location loc, unsigned e, Type tp);

/// Builds a tensor expression from the given Linalg operation.
/// Returns index of the root expression on success.
Optional<unsigned> buildTensorExpFromLinalg(linalg::GenericOp op);

/// Rebuilds SSA format from a tensor expression.
Value buildExp(PatternRewriter &rewriter, Location loc, unsigned e, Value v0,
Value v1);
Value v1, std::vector<Value> idxs);

private:
/// Private helpers.
Expand Down
109 changes: 109 additions & 0 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,115 @@ LogicalResult OutOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// Sparse Tensor Custom Linalg.Generic Operations.
//===----------------------------------------------------------------------===//

/// Verifies a custom intersect op: the formula region's single block must
/// take exactly two arguments of the operand type and terminate in a
/// linalg_yield whose value has the op's result type.
LogicalResult LinalgIntersectOp::verify() {
  Region &region = formula();
  // Name the block `body` rather than shadowing the formula() accessor.
  Block &body = region.front();
  if (body.getNumArguments() != 2)
    return emitError("block must have 2 arguments");

  // SameTypeOperands guarantees a() and b() agree, so checking against a()
  // covers both. The block arguments receive the operand values, so their
  // types must match as well (previously unchecked).
  Type operandType = a().getType();
  for (BlockArgument arg : body.getArguments())
    if (arg.getType() != operandType)
      return emitError("block argument types must match operand types");

  Type outputType = output().getType();
  LinalgYieldOp yield =
      llvm::dyn_cast_or_null<LinalgYieldOp>(body.getTerminator());
  if (yield == nullptr)
    return emitError("intersect block must end with sparse_tensor.linalg_yield");
  Value retVal = yield.result();
  if (retVal.getType() != outputType)
    return emitError("yield value in block does not match intersect return type");

  return success();
}

/// Verifies a custom union op: the formula region's single block must take
/// exactly two arguments of the operand type and terminate in a linalg_yield
/// whose value has the op's result type.
LogicalResult LinalgUnionOp::verify() {
  Region &region = formula();
  // Name the block `body` rather than shadowing the formula() accessor.
  Block &body = region.front();
  if (body.getNumArguments() != 2)
    return emitError("block must have 2 arguments");

  // SameTypeOperands guarantees a() and b() agree, so checking against a()
  // covers both. The block arguments receive the operand values, so their
  // types must match as well (previously unchecked).
  Type operandType = a().getType();
  for (BlockArgument arg : body.getArguments())
    if (arg.getType() != operandType)
      return emitError("block argument types must match operand types");

  Type outputType = output().getType();
  LinalgYieldOp yield =
      llvm::dyn_cast_or_null<LinalgYieldOp>(body.getTerminator());
  if (yield == nullptr)
    return emitError("union block must end with sparse_tensor.linalg_yield");
  Value retVal = yield.result();
  if (retVal.getType() != outputType)
    return emitError("yield value in block does not match union return type");

  return success();
}

/// Verifies a custom reduce op: the two-argument formula block and the
/// zero-argument init block must each end in a linalg_yield producing a
/// value of the op's result type.
LogicalResult LinalgReduceOp::verify() {
  // Check the combining formula: exactly two block arguments.
  Block &formulaBlock = formula().front();
  if (formulaBlock.getNumArguments() != 2)
    return emitError("formula block must have 2 arguments");

  Type outputType = output().getType();

  // The formula must terminate in a linalg_yield of the result type.
  auto formulaYield =
      llvm::dyn_cast_or_null<LinalgYieldOp>(formulaBlock.getTerminator());
  if (!formulaYield)
    return emitError("reduce block must end with sparse_tensor.linalg_yield");
  if (formulaYield.result().getType() != outputType)
    return emitError("yield value in formula block does not match reduce return type");

  // The init region supplies the reduction's starting value; it takes no
  // arguments and must also yield a value of the result type.
  Block &initBlock = init().front();
  if (initBlock.getNumArguments() != 0)
    return emitError("init block must not have any arguments");

  auto initYield =
      llvm::dyn_cast_or_null<LinalgYieldOp>(initBlock.getTerminator());
  if (!initYield)
    return emitError("init block must end with sparse_tensor.linalg_yield");
  if (initYield.result().getType() != outputType)
    return emitError("yield value in init block does not match reduce return type");

  return success();
}

/// Verifies a custom apply op: its single-block formula region takes between
/// one and three arguments and ends in a linalg_yield producing a value of
/// the op's result type.
LogicalResult LinalgApplyOp::verify() {
  Block &body = formula().front();
  unsigned numArgs = body.getNumArguments();
  if (numArgs < 1)
    return emitError("block must have at least 1 argument");
  if (numArgs > 3)
    return emitError("block must have no more than 3 arguments");

  // The terminator must be a linalg_yield whose type matches the result.
  auto yield = llvm::dyn_cast_or_null<LinalgYieldOp>(body.getTerminator());
  if (!yield)
    return emitError("apply block must end with sparse_tensor.linalg_yield");
  if (yield.result().getType() != output().getType())
    return emitError("yield value in block does not match apply return type");

  return success();
}

/// Verifies a mask op: its expression region must end in a linalg_yield that
/// produces a boolean (i1) value.
LogicalResult LinalgMaskOp::verify() {
  Block &body = expr().front();
  auto yield = llvm::dyn_cast_or_null<LinalgYieldOp>(body.getTerminator());
  if (!yield)
    return emitError("mask block must end with sparse_tensor.linalg_yield");

  // The yielded value gates the output, so it must be a 1-bit integer.
  auto intType = yield.result().getType().dyn_cast<IntegerType>();
  if (!intType || intType.getWidth() != 1)
    return emitError("mask block must return i1 type");

  return success();
}

//===----------------------------------------------------------------------===//
// TensorDialect Methods.
//===----------------------------------------------------------------------===//
Expand Down
Loading