
Commit 96933b7

move code to rewrite pass

Parent: c4f6868

4 files changed: +310 −2 lines
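This commit fills in LowerMatrixSelectRewrite, which previously only returned failure(), with a full lowering for graphblas.matrix_select covering the triu, tril, and gt0 selectors. The pattern is registered in populateGraphBLASLoweringPatterns, and new FileCheck tests exercise the gt0 and tril selectors.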

mlir_graphblas/src/lib/GraphBLAS/GraphBLASPasses.cpp

Lines changed: 134 additions & 2 deletions
@@ -182,8 +182,139 @@ class LowerMatrixSelectRewrite : public OpRewritePattern<graphblas::MatrixSelect
 public:
   using OpRewritePattern<graphblas::MatrixSelectOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(graphblas::MatrixSelectOp op, PatternRewriter &rewriter) const {
-    // TODO sanity check that the sparse encoding is sane
-    return failure();
+    ModuleOp module = op->getParentOfType<ModuleOp>();
+    Location loc = op->getLoc();
+
+    Value input = op.input();
+    Type valueType = input.getType().dyn_cast<TensorType>().getElementType();
+    Type int64Type = rewriter.getIntegerType(64);
+    FloatType float64Type = rewriter.getF64Type();
+    Type indexType = rewriter.getIndexType();
+    Type memref1DI64Type = MemRefType::get({-1}, int64Type);
+    Type memref1DValueType = MemRefType::get({-1}, valueType);
+
+    StringRef selector = op.selector();
+
+    bool needs_col = false, needs_val = false;
+    if (selector == "triu")
+    {
+      needs_col = true;
+      needs_val = false;
+    }
+    else if (selector == "tril")
+    {
+      needs_col = true;
+      needs_val = false;
+    }
+    else if (selector == "gt0")
+    {
+      needs_col = false;
+      needs_val = true;
+    }
+    else
+    {
+      return failure();
+    }
+
+    // Initial constants
+    Value c0 = rewriter.create<ConstantIndexOp>(loc, 0);
+    Value c1 = rewriter.create<ConstantIndexOp>(loc, 1);
+    Value c0_64 = rewriter.create<ConstantIntOp>(loc, 0, int64Type);
+    Value c1_64 = rewriter.create<ConstantIntOp>(loc, 1, int64Type);
+    Value cf0 = rewriter.create<ConstantFloatOp>(loc, APFloat(0.0), float64Type);
+
+    // Get sparse tensor info
+    Value nrow = rewriter.create<memref::DimOp>(loc, input, c0);
+    Value Ap = rewriter.create<sparse_tensor::ToPointersOp>(loc, memref1DI64Type, input, c1);
+    Value Aj = rewriter.create<sparse_tensor::ToIndicesOp>(loc, memref1DI64Type, input, c1);
+    Value Ax = rewriter.create<sparse_tensor::ToValuesOp>(loc, memref1DValueType, input);
+
+    Value output = callDupTensor(rewriter, module, loc, input).getResult(0);
+    Value Bp = rewriter.create<sparse_tensor::ToPointersOp>(loc, memref1DI64Type, output, c1);
+    Value Bj = rewriter.create<sparse_tensor::ToIndicesOp>(loc, memref1DI64Type, output, c1);
+    Value Bx = rewriter.create<sparse_tensor::ToValuesOp>(loc, memref1DValueType, output);
+
+    rewriter.create<memref::StoreOp>(loc, c0_64, Bp, c0);
+    // Loop
+    scf::ForOp outerLoop = rewriter.create<scf::ForOp>(loc, c0, nrow, c1);
+    Value row = outerLoop.getInductionVar();
+
+    rewriter.setInsertionPointToStart(outerLoop.getBody());
+    Value row_plus1 = rewriter.create<mlir::AddIOp>(loc, row, c1);
+    Value bp_curr_count = rewriter.create<memref::LoadOp>(loc, Bp, row);
+    rewriter.create<memref::StoreOp>(loc, bp_curr_count, Bp, row_plus1);
+
+    Value j_start_64 = rewriter.create<memref::LoadOp>(loc, Ap, row);
+    Value j_end_64 = rewriter.create<memref::LoadOp>(loc, Ap, row_plus1);
+    Value j_start = rewriter.create<mlir::IndexCastOp>(loc, j_start_64, indexType);
+    Value j_end = rewriter.create<mlir::IndexCastOp>(loc, j_end_64, indexType);
+
+    scf::ForOp innerLoop = rewriter.create<scf::ForOp>(loc, j_start, j_end, c1);
+
+    Value jj = innerLoop.getInductionVar();
+
+    rewriter.setInsertionPointToStart(innerLoop.getBody());
+    Value col_64, col, val, keep;
+    if (needs_col)
+    {
+      col_64 = rewriter.create<memref::LoadOp>(loc, Aj, jj);
+      col = rewriter.create<mlir::IndexCastOp>(loc, col_64, indexType);
+    }
+    if (needs_val)
+    {
+      val = rewriter.create<memref::LoadOp>(loc, Ax, jj);
+    }
+    if (selector == "triu")
+    {
+      keep = rewriter.create<mlir::CmpIOp>(loc, mlir::CmpIPredicate::ugt, col, row);
+    }
+    else if (selector == "tril")
+    {
+      keep = rewriter.create<mlir::CmpIOp>(loc, mlir::CmpIPredicate::ult, col, row);
+    }
+    else if (selector == "gt0")
+    {
+      keep = rewriter.create<mlir::CmpFOp>(loc, mlir::CmpFPredicate::OGT, val, cf0);
+    }
+    else
+    {
+      return failure();
+    }
+
+    scf::IfOp ifKeep = rewriter.create<scf::IfOp>(loc, keep, false /* no else region */);
+
+    rewriter.setInsertionPointToStart(ifKeep.thenBlock());
+
+    Value bj_pos_64 = rewriter.create<memref::LoadOp>(loc, Bp, row_plus1);
+    Value bj_pos = rewriter.create<mlir::IndexCastOp>(loc, bj_pos_64, indexType);
+
+    if (!needs_col)
+    {
+      col_64 = rewriter.create<memref::LoadOp>(loc, Aj, jj);
+    }
+    rewriter.create<memref::StoreOp>(loc, col_64, Bj, bj_pos);
+
+    if (!needs_val)
+    {
+      val = rewriter.create<memref::LoadOp>(loc, Ax, jj);
+    }
+    rewriter.create<memref::StoreOp>(loc, val, Bx, bj_pos);
+
+    Value bj_pos_plus1 = rewriter.create<mlir::AddIOp>(loc, bj_pos_64, c1_64);
+    rewriter.create<memref::StoreOp>(loc, bj_pos_plus1, Bp, row_plus1);
+
+    rewriter.setInsertionPointAfter(outerLoop);
+
+    // trim excess values
+    Value nnz_64 = rewriter.create<memref::LoadOp>(loc, Bp, nrow);
+    Value nnz = rewriter.create<mlir::IndexCastOp>(loc, nnz_64, indexType);
+
+    callResizeIndex(rewriter, module, loc, output, c1, nnz);
+    callResizeValues(rewriter, module, loc, output, nnz);
+
+    rewriter.replaceOp(op, output);
+
+    return success();
   };
 };

@@ -343,6 +474,7 @@ class LowerMatrixMultiplyRewrite : public OpRewritePattern<graphblas::MatrixMult
 
 void populateGraphBLASLoweringPatterns(RewritePatternSet &patterns) {
   patterns.add<
+      LowerMatrixSelectRewrite,
       LowerMatrixReduceToScalarRewrite,
       LowerMatrixMultiplyRewrite,
       LowerTransposeRewrite
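For reference, the rewrite emits the classic CSR filter loop: each row's extent is read from the input pointer array, kept entries are copied into a duplicate of the input tensor, and the output pointer array doubles as a running count that becomes the new nnz. A minimal C++ sketch of the computation the generated scf.for/scf.if nest performs, using the pass's Ap/Aj/Ax and Bp/Bj/Bx names (this standalone function and its `keep` callback are illustrative, not part of the commit):

#include <cstdint>
#include <functional>

// Sketch of the loop nest LowerMatrixSelectRewrite generates (f64 values, as
// in the gt0 path). `keep` stands in for the triu/tril/gt0 predicate.
void csrSelect(int64_t nrow,
               const int64_t *Ap, const int64_t *Aj, const double *Ax,
               int64_t *Bp, int64_t *Bj, double *Bx,
               const std::function<bool(int64_t, int64_t, double)> &keep) {
  Bp[0] = 0;
  for (int64_t row = 0; row < nrow; ++row) {
    Bp[row + 1] = Bp[row];                      // running count of kept entries
    for (int64_t jj = Ap[row]; jj < Ap[row + 1]; ++jj) {
      if (keep(row, Aj[jj], Ax[jj])) {          // e.g. gt0: Ax[jj] > 0.0
        int64_t pos = Bp[row + 1];
        Bj[pos] = Aj[jj];                       // copy column index
        Bx[pos] = Ax[jj];                       // copy value
        Bp[row + 1] = pos + 1;
      }
    }
  }
  // The final nnz is Bp[nrow]; the pass then shrinks the output's index and
  // value storage to that length via resize_index / resize_values.
}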
Lines changed: 58 additions & 0 deletions
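The first new test file covers the gt0 selector: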
@@ -0,0 +1,58 @@
+// RUN: graphblas-opt %s | graphblas-opt --graphblas-lower | FileCheck %s
+
+#CSR64 = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  dimOrdering = affine_map<(i,j) -> (i,j)>,
+  pointerBitWidth = 64,
+  indexBitWidth = 64
+}>
+
+// CHECK-LABEL: func @select_gt0(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>>) -> tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>> {
+// CHECK: %[[VAL_1:.*]] = constant 0 : index
+// CHECK: %[[VAL_2:.*]] = constant 1 : index
+// CHECK: %[[VAL_3:.*]] = constant 0 : i64
+// CHECK: %[[VAL_4:.*]] = constant 1 : i64
+// CHECK: %[[VAL_5:.*]] = constant 0.000000e+00 : f64
+// CHECK: %[[VAL_6:.*]] = memref.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>>
+// CHECK: %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>> to memref<?xi64>
+// CHECK: %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>> to memref<?xi64>
+// CHECK: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>> to memref<?xf64>
+// CHECK: %[[VAL_11:.*]] = call @dup_tensor(%[[VAL_0]]) : (tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>>) -> tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>>
+// CHECK: %[[VAL_12:.*]] = sparse_tensor.pointers %[[VAL_11]], %[[VAL_2]] : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>> to memref<?xi64>
+// CHECK: %[[VAL_13:.*]] = sparse_tensor.indices %[[VAL_11]], %[[VAL_2]] : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>> to memref<?xi64>
+// CHECK: %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_11]] : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>> to memref<?xf64>
+// CHECK: memref.store %[[VAL_3]], %[[VAL_12]]{{\[}}%[[VAL_1]]] : memref<?xi64>
+// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_1]] to %[[VAL_6]] step %[[VAL_2]] {
+// CHECK: %[[VAL_16:.*]] = addi %[[VAL_15]], %[[VAL_2]] : index
+// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_15]]] : memref<?xi64>
+// CHECK: memref.store %[[VAL_17]], %[[VAL_12]]{{\[}}%[[VAL_16]]] : memref<?xi64>
+// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_15]]] : memref<?xi64>
+// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref<?xi64>
+// CHECK: %[[VAL_20:.*]] = index_cast %[[VAL_18]] : i64 to index
+// CHECK: %[[VAL_21:.*]] = index_cast %[[VAL_19]] : i64 to index
+// CHECK: scf.for %[[VAL_22:.*]] = %[[VAL_20]] to %[[VAL_21]] step %[[VAL_2]] {
+// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_22]]] : memref<?xf64>
+// CHECK: %[[VAL_24:.*]] = cmpf ogt, %[[VAL_23]], %[[VAL_5]] : f64
+// CHECK: scf.if %[[VAL_24]] {
+// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_16]]] : memref<?xi64>
+// CHECK: %[[VAL_26:.*]] = index_cast %[[VAL_25]] : i64 to index
+// CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_22]]] : memref<?xi64>
+// CHECK: memref.store %[[VAL_27]], %[[VAL_13]]{{\[}}%[[VAL_26]]] : memref<?xi64>
+// CHECK: memref.store %[[VAL_23]], %[[VAL_14]]{{\[}}%[[VAL_26]]] : memref<?xf64>
+// CHECK: %[[VAL_28:.*]] = addi %[[VAL_25]], %[[VAL_4]] : i64
+// CHECK: memref.store %[[VAL_28]], %[[VAL_12]]{{\[}}%[[VAL_16]]] : memref<?xi64>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_6]]] : memref<?xi64>
+// CHECK: %[[VAL_30:.*]] = index_cast %[[VAL_29]] : i64 to index
+// CHECK: call @resize_index(%[[VAL_11]], %[[VAL_2]], %[[VAL_30]]) : (tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>>, index, index) -> ()
+// CHECK: call @resize_values(%[[VAL_11]], %[[VAL_30]]) : (tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>>, index) -> ()
+// CHECK: return %[[VAL_11]] : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>>
+// CHECK: }
+
+func @select_gt0(%sparse_tensor: tensor<?x?xf64, #CSR64>) -> tensor<?x?xf64, #CSR64> {
+    %answer = graphblas.matrix_select %sparse_tensor { selector = "gt0" } : tensor<?x?xf64, #CSR64>
+    return %answer : tensor<?x?xf64, #CSR64>
+}
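The gt0 CHECK lines pin down the full lowered loop nest. Note the asymmetry produced by the needs_col/needs_val flags: the inner loop loads only the value for the cmpf ogt test, and the column index is loaded lazily inside the scf.if once the entry is known to be kept.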
Lines changed: 58 additions & 0 deletions
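The second new test file covers the tril selector: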
@@ -0,0 +1,58 @@
+// RUN: graphblas-opt %s | graphblas-opt --graphblas-lower | FileCheck %s
+
+#CSR64 = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  dimOrdering = affine_map<(i,j) -> (i,j)>,
+  pointerBitWidth = 64,
+  indexBitWidth = 64
+}>
+
+// CHECK-LABEL: func @select_tril(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>>) -> tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>> {
+// CHECK: %[[VAL_1:.*]] = constant 0 : index
+// CHECK: %[[VAL_2:.*]] = constant 1 : index
+// CHECK: %[[VAL_3:.*]] = constant 0 : i64
+// CHECK: %[[VAL_4:.*]] = constant 1 : i64
+// CHECK: %[[VAL_6:.*]] = memref.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>>
+// CHECK: %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>> to memref<?xi64>
+// CHECK: %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>> to memref<?xi64>
+// CHECK: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>> to memref<?xf64>
+// CHECK: %[[VAL_11:.*]] = call @dup_tensor(%[[VAL_0]]) : (tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>>) -> tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>>
+// CHECK: %[[VAL_12:.*]] = sparse_tensor.pointers %[[VAL_11]], %[[VAL_2]] : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>> to memref<?xi64>
+// CHECK: %[[VAL_13:.*]] = sparse_tensor.indices %[[VAL_11]], %[[VAL_2]] : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>> to memref<?xi64>
+// CHECK: %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_11]] : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>> to memref<?xf64>
+// CHECK: memref.store %[[VAL_3]], %[[VAL_12]]{{\[}}%[[VAL_1]]] : memref<?xi64>
+// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_1]] to %[[VAL_6]] step %[[VAL_2]] {
+// CHECK: %[[VAL_16:.*]] = addi %[[VAL_15]], %[[VAL_2]] : index
+// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_15]]] : memref<?xi64>
+// CHECK: memref.store %[[VAL_17]], %[[VAL_12]]{{\[}}%[[VAL_16]]] : memref<?xi64>
+// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_15]]] : memref<?xi64>
+// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref<?xi64>
+// CHECK: %[[VAL_20:.*]] = index_cast %[[VAL_18]] : i64 to index
+// CHECK: %[[VAL_21:.*]] = index_cast %[[VAL_19]] : i64 to index
+// CHECK: scf.for %[[VAL_22:.*]] = %[[VAL_20]] to %[[VAL_21]] step %[[VAL_2]] {
+// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_22]]] : memref<?xi64>
+// CHECK: %[[VAL_24:.*]] = index_cast %[[VAL_23]] : i64 to index
+// CHECK: %[[VAL_25:.*]] = cmpi ult, %[[VAL_24]], %[[VAL_15]] : index
+// CHECK: scf.if %[[VAL_25]] {
+// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_16]]] : memref<?xi64>
+// CHECK: %[[VAL_27:.*]] = index_cast %[[VAL_26]] : i64 to index
+// CHECK: memref.store %[[VAL_23]], %[[VAL_13]]{{\[}}%[[VAL_27]]] : memref<?xi64>
+// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_22]]] : memref<?xf64>
+// CHECK: memref.store %[[VAL_28]], %[[VAL_14]]{{\[}}%[[VAL_27]]] : memref<?xf64>
+// CHECK: %[[VAL_29:.*]] = addi %[[VAL_26]], %[[VAL_4]] : i64
+// CHECK: memref.store %[[VAL_29]], %[[VAL_12]]{{\[}}%[[VAL_16]]] : memref<?xi64>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_6]]] : memref<?xi64>
+// CHECK: %[[VAL_31:.*]] = index_cast %[[VAL_30]] : i64 to index
+// CHECK: call @resize_index(%[[VAL_11]], %[[VAL_2]], %[[VAL_31]]) : (tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>>, index, index) -> ()
+// CHECK: call @resize_values(%[[VAL_11]], %[[VAL_31]]) : (tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>>, index) -> ()
+// CHECK: return %[[VAL_11]] : tensor<?x?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, pointerBitWidth = 64, indexBitWidth = 64 }>>
+// CHECK: }
+
+func @select_tril(%sparse_tensor: tensor<?x?xf64, #CSR64>) -> tensor<?x?xf64, #CSR64> {
+    %answer = graphblas.matrix_select %sparse_tensor { selector = "tril" } : tensor<?x?xf64, #CSR64>
+    return %answer : tensor<?x?xf64, #CSR64>
+}
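The tril test shows the mirror image: the column index is loaded and compared with cmpi ult in the inner loop, while the value is loaded lazily inside the scf.if. There is no test for triu in this commit, but per the pattern above it lowers the same way with a cmpi ugt comparison. Both tests share the RUN line graphblas-opt %s | graphblas-opt --graphblas-lower | FileCheck %s.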
