Commit c62b844

Add verifiers and tests for all GraphBLAS ops (#65)
1 parent a4856b0 commit c62b844

12 files changed: +1152 -70 lines

mlir_graphblas/src/include/GraphBLAS/GraphBLASDialect.td

Lines changed: 0 additions & 7 deletions
@@ -22,11 +22,4 @@ def GraphBLAS_Dialect : Dialect {
     let cppNamespace = "::mlir::graphblas";
 }
 
-//===--------------------------------------------------------------------===//
-// Base graphblas operation definition.
-//===--------------------------------------------------------------------===//
-
-class GraphBLAS_Op<string mnemonic, list<OpTrait> traits = []> :
-    Op<GraphBLAS_Dialect, mnemonic, traits>;
-
 #endif // GRAPHBLAS_DIALECT

mlir_graphblas/src/include/GraphBLAS/GraphBLASOps.td

Lines changed: 4 additions & 15 deletions
@@ -10,6 +10,10 @@
 include "GraphBLASDialect.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
+class GraphBLAS_Op<string mnemonic, list<OpTrait> traits = []> : Op<GraphBLAS_Dialect, mnemonic, traits> {
+    let verifier = [{ return ::verify(*this); }];
+}
+
 def GraphBLAS_TransposeOp : GraphBLAS_Op<"transpose", [NoSideEffect]> {
     let summary = "transpose operation";
     let description = [{
@@ -31,9 +35,6 @@ def GraphBLAS_TransposeOp : GraphBLAS_Op<"transpose", [NoSideEffect]> {
     let assemblyFormat = [{
        $input attr-dict `:` type($input) `to` type($output)
     }];
-
-    // TODO add custom verifier sanity checking the input and output types are sane
-    // let verifier =
 }
 
 def GraphBLAS_MatrixSelectOp : GraphBLAS_Op<"matrix_select", [NoSideEffect, SameOperandsAndResultType]> {
@@ -55,9 +56,6 @@ def GraphBLAS_MatrixSelectOp : GraphBLAS_Op<"matrix_select", [NoSideEffect, Same
     let assemblyFormat = [{
        $input attr-dict `:` type($input)
     }];
-
-    // TODO add custom verifier sanity checking the selector attribute is sane
-    // let verifier =
 }
 
 def GraphBLAS_MatrixReduceToScalarOp : GraphBLAS_Op<"matrix_reduce_to_scalar", [NoSideEffect]> {
@@ -78,9 +76,6 @@ def GraphBLAS_MatrixReduceToScalarOp : GraphBLAS_Op<"matrix_reduce_to_scalar", [
     let assemblyFormat = [{
        $input attr-dict `:` type($input) `to` type($output)
     }];
-
-    // TODO add custom verifier sanity checking the type of $output is sane ; sanity check the aggregator
-    // let verifier =
 }
 
 def GraphBLAS_MatrixApplyOp : GraphBLAS_Op<"matrix_apply", [NoSideEffect]> {
@@ -105,9 +100,6 @@ def GraphBLAS_MatrixApplyOp : GraphBLAS_Op<"matrix_apply", [NoSideEffect]> {
     let assemblyFormat = [{
        $input `,` $thunk attr-dict `:` `(` type($input) `,` type($thunk) `)` `to` type($output)
     }];
-
-    // TODO add custom verifier sanity checking the types of $output and $thunk are sane ; sanity check the apply operator
-    // let verifier =
 }
 
 def GraphBLAS_MatrixMultiplyOp : GraphBLAS_Op<"matrix_multiply", [NoSideEffect]> {
@@ -131,9 +123,6 @@ def GraphBLAS_MatrixMultiplyOp : GraphBLAS_Op<"matrix_multiply", [NoSideEffect]>
     let assemblyFormat = [{
        $a `,` $b (`,` $mask^)? attr-dict `:` `(` type($a) `,` type($b) (`,` type($mask)^)? `)` `to` type($output)
     }];
-
-    // TODO add custom verifier sanity checking the types of the inputs and outputs and mask are sane ; sanity check the semiring
-    // let verifier =
 }
 
 #endif // GRAPHBLAS_OPS
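
As an illustration of the declared assembly format, a transpose these verifiers accept might be written as below. The #CSR64 alias is an assumed CSR-style encoding (spelled out after the checkCompressedSparseTensor helper further down), and the swap_sizes attribute spelling is inferred from the op.swap_sizes() accessor in GraphBLASOps.cpp; treat this as a sketch, not IR taken from this commit's tests.

// Swap the 2x3 shape to 3x2; with swap_sizes = true, verify(TransposeOp)
// requires the input and output sparse encodings to be identical.
%answer = graphblas.transpose %sparse_tensor { swap_sizes = true }
    : tensor<2x3xf64, #CSR64> to tensor<3x2xf64, #CSR64>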

mlir_graphblas/src/include/GraphBLAS/GraphBLASUtils.h

Lines changed: 3 additions & 2 deletions
@@ -5,9 +5,10 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "llvm/ADT/APInt.h"
 
 
-mlir::RankedTensorType getCSRTensorType(mlir::MLIRContext *context, mlir::Type valueType);
+mlir::RankedTensorType getCSRTensorType(mlir::MLIRContext *context, llvm::ArrayRef<int64_t> shape, mlir::Type valueType);
 mlir::CallOp callEmptyLike(mlir::OpBuilder &builder, mlir::ModuleOp &mod, mlir::Location loc, mlir::Value tensor);
 mlir::CallOp callDupTensor(mlir::OpBuilder &builder, mlir::ModuleOp &mod, mlir::Location loc, mlir::Value tensor);
 
@@ -20,4 +21,4 @@ mlir::CallOp callResizeIndex(mlir::OpBuilder &builder, mlir::ModuleOp &mod, mlir
 mlir::CallOp callResizeValues(mlir::OpBuilder &builder, mlir::ModuleOp &mod, mlir::Location loc,
                               mlir::Value tensor, mlir::Value size);
 
-#endif // GRAPHBLAS_GRAPHBLASUTILS_H
+#endif // GRAPHBLAS_GRAPHBLASUTILS_H
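
With the added shape parameter, getCSRTensorType can presumably build fixed-shape as well as fully dynamic CSR tensor types. A sketch of what those types render as in MLIR assembly, assuming the #CSR64 alias used throughout these examples; the calls in the comments are hypothetical:

tensor<?x?xf64, #CSR64>  // hypothetically getCSRTensorType(ctx, {-1, -1}, f64Type)
tensor<3x4xf64, #CSR64>  // hypothetically getCSRTensorType(ctx, {3, 4}, f64Type)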

mlir_graphblas/src/lib/GraphBLAS/GraphBLASOps.cpp

Lines changed: 272 additions & 0 deletions
@@ -4,9 +4,281 @@
 //
 //===--------------------------------------------------------------------===//
 
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "GraphBLAS/GraphBLASOps.h"
 #include "GraphBLAS/GraphBLASDialect.h"
 #include "mlir/IR/OpImplementation.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/None.h"
+
+using namespace mlir;
+using namespace mlir::graphblas;
+
+//===--------------------------------------------------------------------===//
+// Helpers
+//===--------------------------------------------------------------------===//
+
+enum CompressionType { CSR, CSC, EITHER };
+
+static llvm::Optional<std::string> checkCompressedSparseTensor(
+    Type inputType,
+    int inputIndex,
+    CompressionType compressionType
+) {
+  /*
+     Negative values for inputIndex indicate that the input type is the return type.
+     Otherwise, inputIndex indicates which arg inputType corresponds to.
+
+     Returns llvm::None if the given tensor is valid.
+     Returns a string explaining the problem otherwise.
+  */
+
+  std::string inputName = inputIndex < 0 ? "Return value" : "Operand #"+std::to_string(inputIndex);
+
+  mlir::sparse_tensor::SparseTensorEncodingAttr sparseEncoding =
+      mlir::sparse_tensor::getSparseTensorEncoding(inputType);
+  if (!sparseEncoding)
+    return inputName+" must be a sparse tensor.";
+
+  RankedTensorType inputTensorType = inputType.dyn_cast<RankedTensorType>();
+  if (inputTensorType.getRank() != 2)
+    return inputName+" must have rank 2.";
+
+  ArrayRef<mlir::sparse_tensor::SparseTensorEncodingAttr::DimLevelType> compression =
+      sparseEncoding.getDimLevelType();
+  if (compression[0] != mlir::sparse_tensor::SparseTensorEncodingAttr::DimLevelType::Dense ||
+      compression[1] != mlir::sparse_tensor::SparseTensorEncodingAttr::DimLevelType::Compressed)
+    return inputName+" must have CSR or CSC compression, i.e. must have "
+           "dimLevelType = [ \"dense\", \"compressed\" ] in the sparse encoding.";
+
+  if (compressionType != EITHER) {
+
+    AffineMap dimOrdering = sparseEncoding.getDimOrdering();
+    unsigned dimOrdering0 = dimOrdering.getDimPosition(0);
+    unsigned dimOrdering1 = dimOrdering.getDimPosition(1);
+
+    assert(compressionType == CSR || compressionType == CSC);
+
+    if (compressionType == CSR) {
+      if (dimOrdering0 != 0 || dimOrdering1 != 1)
+        return inputName+" must have CSR compression.";
+    } else if (compressionType == CSC) {
+      if (dimOrdering0 != 1 || dimOrdering1 != 0)
+        return inputName+" must have CSC compression.";
+    }
+  }
+
+  return llvm::None;
+}
+
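
Concretely, the CSR and CSC encodings this helper distinguishes differ only in dimOrdering. A sketch of the two aliases in the sparse_tensor encoding syntax this code targets (the 64-bit widths are an assumption; the helper itself does not constrain them). Anything without a sparse encoding, with rank other than 2, or with a different dimLevelType is rejected with the messages above.

#CSR64 = #sparse_tensor.encoding<{
  dimLevelType = [ "dense", "compressed" ],
  dimOrdering = affine_map<(i, j) -> (i, j)>,  // identity ordering: CSR
  pointerBitWidth = 64,
  indexBitWidth = 64
}>

#CSC64 = #sparse_tensor.encoding<{
  dimLevelType = [ "dense", "compressed" ],
  dimOrdering = affine_map<(i, j) -> (j, i)>,  // swapped ordering: CSC
  pointerBitWidth = 64,
  indexBitWidth = 64
}>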
+//===--------------------------------------------------------------------===//
+// GraphBLAS Ops Methods
+//===--------------------------------------------------------------------===//
+
+static LogicalResult verify(MatrixApplyOp op) {
+  Type inputType = op.input().getType();
+  Type thunkType = op.thunk().getType();
+  Type resultType = op.getResult().getType();
+
+  llvm::Optional<std::string> inputCompressionErrorMessage = checkCompressedSparseTensor(inputType, 0, EITHER);
+  if (inputCompressionErrorMessage)
+    return op.emitError(inputCompressionErrorMessage.getValue());
+
+  llvm::Optional<std::string> resultCompressionErrorMessage = checkCompressedSparseTensor(resultType, -1, EITHER);
+  if (resultCompressionErrorMessage)
+    return op.emitError(resultCompressionErrorMessage.getValue());
+
+  RankedTensorType inputTensorType = inputType.dyn_cast<RankedTensorType>();
+  RankedTensorType resultTensorType = resultType.dyn_cast<RankedTensorType>();
+
+  if (inputTensorType.getElementType() != thunkType)
+    return op.emitError("Element type of input tensor does not match type of thunk.");
+
+  if (resultTensorType.getElementType() != thunkType)
+    // TODO this is not always correct, e.g. matrix_apply_less_than(tensor<f64>, 2.3) -> tensor<i1>.
+    return op.emitError("Element type of result tensor does not match type of thunk.");
+
+  ArrayRef<int64_t> inputShape = inputTensorType.getShape();
+  ArrayRef<int64_t> resultShape = resultTensorType.getShape();
+
+  // TODO intelligently handle arbitrarily shaped tensors, i.e. tensors with shapes using "?"
+  if (inputShape[0] != resultShape[0] || inputShape[1] != resultShape[1])
+    return op.emitError("Input shape does not match output shape.");
+
+  static const std::vector<std::string> supportedOperators{"min"};
+  std::string applyOperator = op.apply_operator().str();
+  bool operatorSupported = std::find(supportedOperators.begin(), supportedOperators.end(), applyOperator)
+      != supportedOperators.end();
+  if (!operatorSupported)
+    return op.emitError("\""+applyOperator+"\" is not a supported operator.");
+
+  return success();
+}
+
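
A sketch of IR that passes this verifier, assuming the aliases above; the thunk's f64 type must match both the input and result element types:

%thunk = constant 100.0 : f64
%answer = graphblas.matrix_apply %sparse_tensor, %thunk { apply_operator = "min" }
    : (tensor<7x7xf64, #CSR64>, f64) to tensor<7x7xf64, #CSR64>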
+static LogicalResult verify(MatrixMultiplyOp op) {
+  Type aType = op.a().getType();
+  Type bType = op.b().getType();
+  Type resultType = op.getResult().getType();
+
+  llvm::Optional<std::string> aCompressionErrorMessage = checkCompressedSparseTensor(aType, 0, CSR);
+  if (aCompressionErrorMessage)
+    return op.emitError(aCompressionErrorMessage.getValue());
+
+  llvm::Optional<std::string> bCompressionErrorMessage = checkCompressedSparseTensor(bType, 1, CSC);
+  if (bCompressionErrorMessage)
+    return op.emitError(bCompressionErrorMessage.getValue());
+
+  llvm::Optional<std::string> resultCompressionErrorMessage = checkCompressedSparseTensor(resultType, -1, CSR);
+  if (resultCompressionErrorMessage)
+    return op.emitError(resultCompressionErrorMessage.getValue());
+
+  static const std::vector<std::string> supportedSemirings{"plus_times", "plus_pair", "plus_plus"};
+  std::string semiring = op.semiring().str();
+  bool semiringSupported = std::find(supportedSemirings.begin(), supportedSemirings.end(), semiring)
+      != supportedSemirings.end();
+  if (!semiringSupported)
+    return op.emitError("\""+semiring+"\" is not a supported semiring.");
+
+  RankedTensorType aTensorType = aType.dyn_cast<RankedTensorType>();
+  RankedTensorType bTensorType = bType.dyn_cast<RankedTensorType>();
+  RankedTensorType resultTensorType = resultType.dyn_cast<RankedTensorType>();
+
+  ArrayRef<int64_t> aShape = aTensorType.getShape();
+  ArrayRef<int64_t> bShape = bTensorType.getShape();
+  ArrayRef<int64_t> resultShape = resultTensorType.getShape();
+  // TODO intelligently handle arbitrarily shaped tensors, i.e. tensors with shapes using "?"
+  if (aShape[1] != bShape[0])
+    return op.emitError("Operand shapes are incompatible.");
+  if (resultShape[0] != aShape[0] || resultShape[1] != bShape[1])
+    return op.emitError("Operand shapes incompatible with output shape.");
+
+  if (aTensorType.getElementType() != bTensorType.getElementType())
+    return op.emitError("Operand element types must be identical.");
+  if (aTensorType.getElementType() != resultTensorType.getElementType())
+    return op.emitError("Result element type differs from the input element types.");
+
+  Value mask = op.mask();
+  if (mask) {
+    Type maskType = mask.getType();
+    llvm::Optional<std::string> maskCompressionErrorMessage = checkCompressedSparseTensor(maskType, 2, CSR);
+    if (maskCompressionErrorMessage)
+      return op.emitError(maskCompressionErrorMessage.getValue());
+
+    RankedTensorType maskTensorType = maskType.dyn_cast<RankedTensorType>();
+    ArrayRef<int64_t> maskShape = maskTensorType.getShape();
+    if (resultShape[0] != maskShape[0] || resultShape[1] != maskShape[1])
+      return op.emitError("Mask shape must match output shape.");
+  }
+
+  return success();
+}
+
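
A sketch of IR satisfying these checks: %a is CSR, %b is CSC, and the result and optional mask are CSR with the matching 4x6 shape (aliases as above):

%answer = graphblas.matrix_multiply %a, %b { semiring = "plus_times" }
    : (tensor<4x5xf64, #CSR64>, tensor<5x6xf64, #CSC64>) to tensor<4x6xf64, #CSR64>

// With the optional mask, which must also be CSR and match the output shape.
%masked = graphblas.matrix_multiply %a, %b, %mask { semiring = "plus_pair" }
    : (tensor<4x5xf64, #CSR64>, tensor<5x6xf64, #CSC64>, tensor<4x6xf64, #CSR64>)
    to tensor<4x6xf64, #CSR64>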
+static LogicalResult verify(MatrixReduceToScalarOp op) {
+  Type operandType = op.input().getType();
+
+  llvm::Optional<std::string> compressionErrorMessage = checkCompressedSparseTensor(operandType, 0, EITHER);
+  if (compressionErrorMessage)
+    return op.emitError(compressionErrorMessage.getValue());
+
+  static const std::vector<std::string> supportedAggregators{"sum"};
+  std::string aggregator = op.aggregator().str();
+  bool aggregatorSupported = std::find(supportedAggregators.begin(), supportedAggregators.end(), aggregator)
+      != supportedAggregators.end();
+  if (!aggregatorSupported)
+    return op.emitError("\""+aggregator+"\" is not a supported aggregator.");
+
+  Type resultType = op.getResult().getType();
+  RankedTensorType operandTensorType = operandType.dyn_cast<RankedTensorType>();
+  if (resultType != operandTensorType.getElementType())
+    return op.emitError("Operand and output types are incompatible.");
+
+  return success();
+}
+
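
A sketch of a conforming reduction; the f64 result type must equal the operand's element type:

%answer = graphblas.matrix_reduce_to_scalar %sparse_tensor { aggregator = "sum" }
    : tensor<7x7xf64, #CSR64> to f64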
+static LogicalResult verify(MatrixSelectOp op) {
+  // input and result types are already guaranteed to be the same
+  Type resultType = op.getResult().getType();
+
+  llvm::Optional<std::string> resultCompressionErrorMessage = checkCompressedSparseTensor(resultType, -1, EITHER);
+  if (resultCompressionErrorMessage)
+    return op.emitError(resultCompressionErrorMessage.getValue());
+
+  static const std::vector<std::string> supportedSelectors{"triu", "tril", "gt0"};
+  std::string selector = op.selector().str();
+  bool selectorSupported = std::find(supportedSelectors.begin(), supportedSelectors.end(), selector)
+      != supportedSelectors.end();
+  if (!selectorSupported)
+    return op.emitError("\""+selector+"\" is not a supported selector.");
+
+  return success();
+}
+
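
A sketch of a conforming select; the SameOperandsAndResultType trait already forces the input and output types to match:

%answer = graphblas.matrix_select %sparse_tensor { selector = "triu" }
    : tensor<100x100xf64, #CSR64>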
+static LogicalResult verify(TransposeOp op) {
+  Type inputType = op.input().getType();
+  Type resultType = op.getResult().getType();
+
+  llvm::Optional<std::string> inputCompressionErrorMessage = checkCompressedSparseTensor(inputType, 0, EITHER);
+  if (inputCompressionErrorMessage)
+    return op.emitError(inputCompressionErrorMessage.getValue());
+
+  llvm::Optional<std::string> resultCompressionErrorMessage = checkCompressedSparseTensor(resultType, -1, EITHER);
+  if (resultCompressionErrorMessage)
+    return op.emitError(resultCompressionErrorMessage.getValue());
+
+  // TODO intelligently handle arbitrarily shaped tensors, i.e. tensors with shapes using "?"
+
+  RankedTensorType inputTensorType = inputType.dyn_cast<RankedTensorType>();
+  RankedTensorType resultTensorType = resultType.dyn_cast<RankedTensorType>();
+
+  if (inputTensorType.getElementType() != resultTensorType.getElementType())
+    return op.emitError("Input and output tensors have different element types.");
+
+  ArrayRef<int64_t> inputShape = inputTensorType.getShape();
+  ArrayRef<int64_t> resultShape = resultTensorType.getShape();
+
+  mlir::sparse_tensor::SparseTensorEncodingAttr inputSparseEncoding =
+      mlir::sparse_tensor::getSparseTensorEncoding(inputType);
+
+  mlir::sparse_tensor::SparseTensorEncodingAttr resultSparseEncoding =
+      mlir::sparse_tensor::getSparseTensorEncoding(resultType);
+
+  bool swapSizes = op.swap_sizes();
+  if (swapSizes) {
+    if (inputShape[0] != resultShape[1] || inputShape[1] != resultShape[0])
+      return op.emitError("Input and output shapes are expected to be swapped.");
+    if (inputSparseEncoding != resultSparseEncoding)
+      return op.emitError("Input and output tensors are expected to have identical sparse encodings.");
+  } else {
+    if (inputShape[0] != resultShape[0] || inputShape[1] != resultShape[1])
+      return op.emitError("Input and output shapes are expected to be the same.");
+
+    AffineMap inputDimOrdering = inputSparseEncoding.getDimOrdering();
+    AffineMap resultDimOrdering = resultSparseEncoding.getDimOrdering();
+    unsigned inputDimOrdering0 = inputDimOrdering.getDimPosition(0);
+    unsigned inputDimOrdering1 = inputDimOrdering.getDimPosition(1);
+    unsigned resultDimOrdering0 = resultDimOrdering.getDimPosition(0);
+    unsigned resultDimOrdering1 = resultDimOrdering.getDimPosition(1);
+    if (inputDimOrdering0 != resultDimOrdering1 || inputDimOrdering1 != resultDimOrdering0)
+      return op.emitError("Sparse encoding dimension orderings of input and result tensors "
+                          "expected to be swapped.");
+
+    // TODO should we be more lenient like the sparse tensor dialect is via isMatchingWidth?
+    // see llvm-project/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+    unsigned inputPointerBitWidth = inputSparseEncoding.getPointerBitWidth();
+    unsigned resultPointerBitWidth = resultSparseEncoding.getPointerBitWidth();
+    if (inputPointerBitWidth != resultPointerBitWidth)
+      return op.emitError("Sparse encoding pointer bit widths of input and result tensors do not match.");
+
+    unsigned inputIndexBitWidth = inputSparseEncoding.getIndexBitWidth();
+    unsigned resultIndexBitWidth = resultSparseEncoding.getIndexBitWidth();
+    if (inputIndexBitWidth != resultIndexBitWidth)
+      return op.emitError("Sparse encoding index bit widths of input and result tensors do not match.");
+
+    // dimLevelType values guaranteed to be the same since we already checked earlier
+  }
+
+  return success();
+}
 
 #define GET_OP_CLASSES
 #include "GraphBLAS/GraphBLASOps.cpp.inc"
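
For contrast, a sketch of IR these verifiers reject: handing matrix_multiply a CSR-encoded second operand should presumably trip the dimOrdering check in checkCompressedSparseTensor, producing the diagnostic quoted in the comment (%b_csr and both aliases are illustrative names):

// Rejected with: Operand #1 must have CSC compression.
%bad = graphblas.matrix_multiply %a, %b_csr { semiring = "plus_times" }
    : (tensor<4x5xf64, #CSR64>, tensor<5x6xf64, #CSR64>) to tensor<4x6xf64, #CSR64>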
