Commit 7f3ab14

Merge pull request #67 from eriknw/select
Implement matrix_select op
2 parents: c62b844 + 96933b7

File tree: 5 files changed (+424, −4 lines)

mlir_graphblas/src/lib/GraphBLAS/GraphBLASPasses.cpp

Lines changed: 134 additions & 2 deletions
@@ -181,8 +181,139 @@ class LowerMatrixSelectRewrite : public OpRewritePattern<graphblas::MatrixSelect
 public:
   using OpRewritePattern<graphblas::MatrixSelectOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(graphblas::MatrixSelectOp op, PatternRewriter &rewriter) const {
-    // TODO sanity check that the sparse encoding is sane
-    return failure();
+    ModuleOp module = op->getParentOfType<ModuleOp>();
+    Location loc = op->getLoc();
+
+    Value input = op.input();
+    Type valueType = input.getType().dyn_cast<TensorType>().getElementType();
+    Type int64Type = rewriter.getIntegerType(64);
+    FloatType float64Type = rewriter.getF64Type();
+    Type indexType = rewriter.getIndexType();
+    Type memref1DI64Type = MemRefType::get({-1}, int64Type);
+    Type memref1DValueType = MemRefType::get({-1}, valueType);
+
+    StringRef selector = op.selector();
+
+    bool needs_col = false, needs_val = false;
+    if (selector == "triu")
+    {
+      needs_col = true;
+      needs_val = false;
+    }
+    else if (selector == "tril")
+    {
+      needs_col = true;
+      needs_val = false;
+    }
+    else if (selector == "gt0")
+    {
+      needs_col = false;
+      needs_val = true;
+    }
+    else
+    {
+      return failure();
+    }
+
+    // Initial constants
+    Value c0 = rewriter.create<ConstantIndexOp>(loc, 0);
+    Value c1 = rewriter.create<ConstantIndexOp>(loc, 1);
+    Value c0_64 = rewriter.create<ConstantIntOp>(loc, 0, int64Type);
+    Value c1_64 = rewriter.create<ConstantIntOp>(loc, 1, int64Type);
+    Value cf0 = rewriter.create<ConstantFloatOp>(loc, APFloat(0.0), float64Type);
+
+    // Get sparse tensor info
+    Value nrow = rewriter.create<memref::DimOp>(loc, input, c0);
+    Value Ap = rewriter.create<sparse_tensor::ToPointersOp>(loc, memref1DI64Type, input, c1);
+    Value Aj = rewriter.create<sparse_tensor::ToIndicesOp>(loc, memref1DI64Type, input, c1);
+    Value Ax = rewriter.create<sparse_tensor::ToValuesOp>(loc, memref1DValueType, input);
+
+    Value output = callDupTensor(rewriter, module, loc, input).getResult(0);
+    Value Bp = rewriter.create<sparse_tensor::ToPointersOp>(loc, memref1DI64Type, output, c1);
+    Value Bj = rewriter.create<sparse_tensor::ToIndicesOp>(loc, memref1DI64Type, output, c1);
+    Value Bx = rewriter.create<sparse_tensor::ToValuesOp>(loc, memref1DValueType, output);
+
+    rewriter.create<memref::StoreOp>(loc, c0_64, Bp, c0);
+    // Loop
+    scf::ForOp outerLoop = rewriter.create<scf::ForOp>(loc, c0, nrow, c1);
+    Value row = outerLoop.getInductionVar();
+
+    rewriter.setInsertionPointToStart(outerLoop.getBody());
+    Value row_plus1 = rewriter.create<mlir::AddIOp>(loc, row, c1);
+    Value bp_curr_count = rewriter.create<memref::LoadOp>(loc, Bp, row);
+    rewriter.create<memref::StoreOp>(loc, bp_curr_count, Bp, row_plus1);
+
+    Value j_start_64 = rewriter.create<memref::LoadOp>(loc, Ap, row);
+    Value j_end_64 = rewriter.create<memref::LoadOp>(loc, Ap, row_plus1);
+    Value j_start = rewriter.create<mlir::IndexCastOp>(loc, j_start_64, indexType);
+    Value j_end = rewriter.create<mlir::IndexCastOp>(loc, j_end_64, indexType);
+
+    scf::ForOp innerLoop = rewriter.create<scf::ForOp>(loc, j_start, j_end, c1);
+
+    Value jj = innerLoop.getInductionVar();
+
+    rewriter.setInsertionPointToStart(innerLoop.getBody());
+    Value col_64, col, val, keep;
+    if (needs_col)
+    {
+      col_64 = rewriter.create<memref::LoadOp>(loc, Aj, jj);
+      col = rewriter.create<mlir::IndexCastOp>(loc, col_64, indexType);
+    }
+    if (needs_val)
+    {
+      val = rewriter.create<memref::LoadOp>(loc, Ax, jj);
+    }
+    if (selector == "triu")
+    {
+      keep = rewriter.create<mlir::CmpIOp>(loc, mlir::CmpIPredicate::ugt, col, row);
+    }
+    else if (selector == "tril")
+    {
+      keep = rewriter.create<mlir::CmpIOp>(loc, mlir::CmpIPredicate::ult, col, row);
+    }
+    else if (selector == "gt0")
+    {
+      keep = rewriter.create<mlir::CmpFOp>(loc, mlir::CmpFPredicate::OGT, val, cf0);
+    }
+    else
+    {
+      return failure();
+    }
+
+    scf::IfOp ifKeep = rewriter.create<scf::IfOp>(loc, keep, false /* no else region */);
+
+    rewriter.setInsertionPointToStart(ifKeep.thenBlock());
+
+    Value bj_pos_64 = rewriter.create<memref::LoadOp>(loc, Bp, row_plus1);
+    Value bj_pos = rewriter.create<mlir::IndexCastOp>(loc, bj_pos_64, indexType);
+
+    if (!needs_col)
+    {
+      col_64 = rewriter.create<memref::LoadOp>(loc, Aj, jj);
+    }
+    rewriter.create<memref::StoreOp>(loc, col_64, Bj, bj_pos);
+
+    if (!needs_val)
+    {
+      val = rewriter.create<memref::LoadOp>(loc, Ax, jj);
+    }
+    rewriter.create<memref::StoreOp>(loc, val, Bx, bj_pos);
+
+    Value bj_pos_plus1 = rewriter.create<mlir::AddIOp>(loc, bj_pos_64, c1_64);
+    rewriter.create<memref::StoreOp>(loc, bj_pos_plus1, Bp, row_plus1);
+
+    rewriter.setInsertionPointAfter(outerLoop);
+
+    // trim excess values
+    Value nnz_64 = rewriter.create<memref::LoadOp>(loc, Bp, nrow);
+    Value nnz = rewriter.create<mlir::IndexCastOp>(loc, nnz_64, indexType);
+
+    callResizeIndex(rewriter, module, loc, output, c1, nnz);
+    callResizeValues(rewriter, module, loc, output, nnz);
+
+    rewriter.replaceOp(op, output);
+
+    return success();
   };
 };
 
@@ -394,6 +525,7 @@ class LowerMatrixMultiplyRewrite : public OpRewritePattern<graphblas::MatrixMult
 
 void populateGraphBLASLoweringPatterns(RewritePatternSet &patterns) {
   patterns.add<
+      LowerMatrixSelectRewrite,
       LowerMatrixReduceToScalarRewrite,
       LowerMatrixMultiplyRewrite,
       LowerTransposeRewrite,

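The lowered loop nest is the classic CSR filter: Bp[row + 1] keeps a running count of accepted entries, Bj/Bx are written in place into a duplicate of the input, and the index and value buffers are trimmed to the final nnz afterwards. As a reading aid only (not part of this commit), here is a minimal standalone C++ sketch of the same algorithm; the CSR struct and matrixSelect name are illustrative:

#include <cstdint>
#include <string>
#include <vector>

// Minimal CSR container; names (CSR, matrixSelect) are illustrative only.
struct CSR {
  int64_t nrows;
  std::vector<int64_t> Ap; // row pointers, length nrows + 1
  std::vector<int64_t> Aj; // column indices, length nnz
  std::vector<double> Ax;  // values, length nnz
};

// Same filtering pattern as the lowered IR: Bp[row + 1] carries a running
// count of kept entries, and Bj/Bx grow only when `keep` holds.
CSR matrixSelect(const CSR &A, const std::string &selector) {
  CSR B{A.nrows, std::vector<int64_t>(A.nrows + 1, 0), {}, {}};
  for (int64_t row = 0; row < A.nrows; ++row) {
    B.Ap[row + 1] = B.Ap[row];
    for (int64_t jj = A.Ap[row]; jj < A.Ap[row + 1]; ++jj) {
      int64_t col = A.Aj[jj];
      double val = A.Ax[jj];
      bool keep = (selector == "triu")   ? col > row
                  : (selector == "tril") ? col < row
                                         : val > 0.0; // "gt0"
      if (keep) {
        B.Aj.push_back(col);
        B.Ax.push_back(val);
        ++B.Ap[row + 1];
      }
    }
  }
  // The pass writes into a duplicate of A and then shrinks the index and
  // value buffers to B.Ap[nrows]; the vectors here simply grow instead.
  return B;
}

Note that triu and tril compare column against row indices only, while gt0 is the only selector that needs the stored values, which is exactly what the needs_col/needs_val flags in the pattern encode.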
mlir_graphblas/src/lowering-test/MatrixSelect.cpp

Lines changed: 114 additions & 2 deletions
@@ -3,6 +3,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
@@ -18,10 +19,32 @@ void addMatrixSelectFunc(mlir::ModuleOp mod, const std::string &selector)
 {
   MLIRContext *context = mod.getContext();
   OpBuilder builder(mod.getBodyRegion());
+  auto loc = builder.getUnknownLoc();
   builder.setInsertionPointToStart(mod.getBody());
 
-  // Create function signature
+  bool needs_col = false, needs_val = false;
+  if (selector == "triu") {
+    needs_col = true;
+    needs_val = false;
+  } else if (selector == "tril") {
+    needs_col = true;
+    needs_val = false;
+  } else if (selector == "gt0") {
+    needs_col = false;
+    needs_val = true;
+  } else {
+    assert(!"invalid selector");
+  }
+
+  // Types
   auto valueType = builder.getF64Type();
+  auto i64Type = builder.getI64Type();
+  auto f64Type = builder.getF64Type();
+  auto indexType = builder.getIndexType();
+  auto memref1DI64Type = MemRefType::get({-1}, i64Type);
+  auto memref1DValueType = MemRefType::get({-1}, valueType);
+
+  // Create function signature
   RankedTensorType csrTensor = getCSRTensorType(context, valueType);
 
   string func_name = "matrix_select_" + selector;
@@ -33,8 +56,97 @@ void addMatrixSelectFunc(mlir::ModuleOp mod, const std::string &selector)
   auto &entry_block = *func.addEntryBlock();
   builder.setInsertionPointToStart(&entry_block);
 
+  auto input = entry_block.getArgument(0);
+
   // add function body ops here
+  // Initial constants
+  Value c0 = builder.create<ConstantIndexOp>(loc, 0);
+  Value c1 = builder.create<ConstantIndexOp>(loc, 1);
+  Value c0_64 = builder.create<ConstantIntOp>(loc, 0, i64Type);
+  Value c1_64 = builder.create<ConstantIntOp>(loc, 1, i64Type);
+  Value cf0 = builder.create<ConstantFloatOp>(loc, APFloat(0.0), f64Type);
+
+  // Get sparse tensor info
+  Value nrow = builder.create<memref::DimOp>(loc, input, c0);
+  Value ncol = builder.create<memref::DimOp>(loc, input, c1);
+  Value Ap = builder.create<ToPointersOp>(loc, memref1DI64Type, input, c1);
+  Value Aj = builder.create<ToIndicesOp>(loc, memref1DI64Type, input, c1);
+  Value Ax = builder.create<ToValuesOp>(loc, memref1DValueType, input);
+
+  Value output = callDupTensor(builder, mod, loc, input).getResult(0);
+  Value Bp = builder.create<ToPointersOp>(loc, memref1DI64Type, output, c1);
+  Value Bj = builder.create<ToIndicesOp>(loc, memref1DI64Type, output, c1);
+  Value Bx = builder.create<ToValuesOp>(loc, memref1DValueType, output);
+
+  builder.create<memref::StoreOp>(loc, c0_64, Bp, c0);
+  // Loop
+  auto outerLoop = builder.create<scf::ForOp>(loc, c0, nrow, c1);
+  Value row = outerLoop.getInductionVar();
+
+  builder.setInsertionPointToStart(outerLoop.getBody());
+  Value row_plus1 = builder.create<mlir::AddIOp>(loc, row, c1);
+  Value bp_curr_count = builder.create<memref::LoadOp>(loc, Bp, row);
+  builder.create<memref::StoreOp>(loc, bp_curr_count, Bp, row_plus1);
+
+  Value j_start_64 = builder.create<memref::LoadOp>(loc, Ap, row);
+  Value j_end_64 = builder.create<memref::LoadOp>(loc, Ap, row_plus1);
+  Value j_start = builder.create<mlir::IndexCastOp>(loc, j_start_64, indexType);
+  Value j_end = builder.create<mlir::IndexCastOp>(loc, j_end_64, indexType);
+
+  auto innerLoop = builder.create<scf::ForOp>(loc, j_start, j_end, c1);
+
+  Value jj = innerLoop.getInductionVar();
+
+  builder.setInsertionPointToStart(innerLoop.getBody());
+  Value col_64, col, val, keep;
+  if (needs_col) {
+    col_64 = builder.create<memref::LoadOp>(loc, Aj, jj);
+    col = builder.create<mlir::IndexCastOp>(loc, col_64, indexType);
+  }
+  if (needs_val) {
+    val = builder.create<memref::LoadOp>(loc, Ax, jj);
+  }
+  if (selector == "triu") {
+    keep = builder.create<mlir::CmpIOp>(loc, mlir::CmpIPredicate::ugt, col, row);
+  }
+  else if (selector == "tril") {
+    keep = builder.create<mlir::CmpIOp>(loc, mlir::CmpIPredicate::ult, col, row);
+  }
+  else if (selector == "gt0") {
+    keep = builder.create<mlir::CmpFOp>(loc, mlir::CmpFPredicate::OGT, val, cf0);
+  }
+  else {
+    assert(!"invalid selector");
+  }
+
+  scf::IfOp ifKeep = builder.create<scf::IfOp>(loc, keep, false /* no else region */);
+
+  builder.setInsertionPointToStart(ifKeep.thenBlock());
+
+  Value bj_pos_64 = builder.create<memref::LoadOp>(loc, Bp, row_plus1);
+  Value bj_pos = builder.create<mlir::IndexCastOp>(loc, bj_pos_64, indexType);
+
+  if (!needs_col) {
+    col_64 = builder.create<memref::LoadOp>(loc, Aj, jj);
+  }
+  builder.create<memref::StoreOp>(loc, col_64, Bj, bj_pos);
+
+  if (!needs_val) {
+    val = builder.create<memref::LoadOp>(loc, Ax, jj);
+  }
+  builder.create<memref::StoreOp>(loc, val, Bx, bj_pos);
+
+  Value bj_pos_plus1 = builder.create<mlir::AddIOp>(loc, bj_pos_64, c1_64);
+  builder.create<memref::StoreOp>(loc, bj_pos_plus1, Bp, row_plus1);
+
+  builder.setInsertionPointAfter(outerLoop);
+
+  Value nnz_64 = builder.create<memref::LoadOp>(loc, Bp, nrow);
+  Value nnz = builder.create<mlir::IndexCastOp>(loc, nnz_64, indexType);
+
+  callResizeIndex(builder, mod, loc, output, c1, nnz);
+  callResizeValues(builder, mod, loc, output, nnz);
 
   // Add return op
-  builder.create<ReturnOp>(builder.getUnknownLoc());
+  builder.create<ReturnOp>(builder.getUnknownLoc(), output);
 }
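This test helper mirrors the rewrite pattern above but emits a standalone matrix_select_<selector> function through an OpBuilder. A hypothetical driver (not in this commit) could emit all three variants into one module for inspection; it assumes the needed dialects (SCF, MemRef, SparseTensor, Standard) were already registered with the context:

#include <string>

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"

// Declaration from MatrixSelect.cpp above.
void addMatrixSelectFunc(mlir::ModuleOp mod, const std::string &selector);

// Hypothetical driver: emit matrix_select_triu/_tril/_gt0 and print the module.
void buildSelectModule(mlir::MLIRContext &context) {
  mlir::OpBuilder builder(&context);
  mlir::ModuleOp mod = mlir::ModuleOp::create(builder.getUnknownLoc());
  for (const std::string selector : {"triu", "tril", "gt0"})
    addMatrixSelectFunc(mod, selector);
  mod.dump(); // print the generated MLIR for inspection
}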
Lines changed: 58 additions & 0 deletions (new file)
@@ -0,0 +1,58 @@
+// RUN: graphblas-opt %s | graphblas-opt --graphblas-lower | FileCheck %s
+
+#CSR64 = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  dimOrdering = affine_map<(i,j) -> (i,j)>,
+  pointerBitWidth = 64,
+  indexBitWidth = 64
+}>
+
+// CHECK-LABEL: func @select_gt0(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>>) -> tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>> {
+// CHECK: %[[VAL_1:.*]] = constant 0 : index
+// CHECK: %[[VAL_2:.*]] = constant 1 : index
+// CHECK: %[[VAL_3:.*]] = constant 0 : i64
+// CHECK: %[[VAL_4:.*]] = constant 1 : i64
+// CHECK: %[[VAL_5:.*]] = constant 0.000000e+00 : f64
+// CHECK: %[[VAL_6:.*]] = memref.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>>
+// CHECK: %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>> to memref<?xi64>
+// CHECK: %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>> to memref<?xi64>
+// CHECK: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>> to memref<?xf64>
+// CHECK: %[[VAL_11:.*]] = call @dup_tensor(%[[VAL_0]]) : (tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>>) -> tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>>
+// CHECK: %[[VAL_12:.*]] = sparse_tensor.pointers %[[VAL_11]], %[[VAL_2]] : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>> to memref<?xi64>
+// CHECK: %[[VAL_13:.*]] = sparse_tensor.indices %[[VAL_11]], %[[VAL_2]] : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>> to memref<?xi64>
+// CHECK: %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_11]] : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>> to memref<?xf64>
+// CHECK: memref.store %[[VAL_3]], %[[VAL_12]]{{\[}}%[[VAL_1]]] : memref<?xi64>
+// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_1]] to %[[VAL_6]] step %[[VAL_2]] {
+// CHECK: %[[VAL_16:.*]] = addi %[[VAL_15]], %[[VAL_2]] : index
+// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_15]]] : memref<?xi64>
+// CHECK: memref.store %[[VAL_17]], %[[VAL_12]]{{\[}}%[[VAL_16]]] : memref<?xi64>
+// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_15]]] : memref<?xi64>
+// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref<?xi64>
+// CHECK: %[[VAL_20:.*]] = index_cast %[[VAL_18]] : i64 to index
+// CHECK: %[[VAL_21:.*]] = index_cast %[[VAL_19]] : i64 to index
+// CHECK: scf.for %[[VAL_22:.*]] = %[[VAL_20]] to %[[VAL_21]] step %[[VAL_2]] {
+// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_22]]] : memref<?xf64>
+// CHECK: %[[VAL_24:.*]] = cmpf ogt, %[[VAL_23]], %[[VAL_5]] : f64
+// CHECK: scf.if %[[VAL_24]] {
+// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_16]]] : memref<?xi64>
+// CHECK: %[[VAL_26:.*]] = index_cast %[[VAL_25]] : i64 to index
+// CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_22]]] : memref<?xi64>
+// CHECK: memref.store %[[VAL_27]], %[[VAL_13]]{{\[}}%[[VAL_26]]] : memref<?xi64>
+// CHECK: memref.store %[[VAL_23]], %[[VAL_14]]{{\[}}%[[VAL_26]]] : memref<?xf64>
+// CHECK: %[[VAL_28:.*]] = addi %[[VAL_25]], %[[VAL_4]] : i64
+// CHECK: memref.store %[[VAL_28]], %[[VAL_12]]{{\[}}%[[VAL_16]]] : memref<?xi64>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_6]]] : memref<?xi64>
+// CHECK: %[[VAL_30:.*]] = index_cast %[[VAL_29]] : i64 to index
+// CHECK: call @resize_index(%[[VAL_11]], %[[VAL_2]], %[[VAL_30]]) : (tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>>, index, index) -> ()
+// CHECK: call @resize_values(%[[VAL_11]], %[[VAL_30]]) : (tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>>, index) -> ()
+// CHECK: return %[[VAL_11]] : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>>
+// CHECK: }
+
+func @select_gt0(%sparse_tensor: tensor<?x?xf64, #CSR64>) -> tensor<?x?xf64, #CSR64> {
+  %answer = graphblas.matrix_select %sparse_tensor { selector = "gt0" } : tensor<?x?xf64, #CSR64>
+  return %answer : tensor<?x?xf64, #CSR64>
+}
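The FileCheck test above pins down the structure of the lowered IR for gt0. As a concrete sanity check in the same spirit, here is a tiny hypothetical usage of the CSR/matrixSelect sketch given after the GraphBLASPasses.cpp diff: selecting gt0 from a 2x2 matrix with stored entries 1.0, -2.0, and 3.0 keeps only the two positive values:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // A = [[1.0, -2.0], [., 3.0]]  (CSR; "." is an implicit zero)
  CSR A{2, {0, 2, 3}, {0, 1, 1}, {1.0, -2.0, 3.0}};
  CSR B = matrixSelect(A, "gt0");
  assert((B.Ap == std::vector<int64_t>{0, 1, 2})); // one entry kept per row
  assert((B.Aj == std::vector<int64_t>{0, 1}));
  assert((B.Ax == std::vector<double>{1.0, 3.0}));
  return 0;
}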
