// NOTE(review): MemRef.h added for memref::{Dim,Load,Store}Op used below;
// drop if it is already included earlier in the file.
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Verifier.h"
@@ -18,10 +19,32 @@ void addMatrixSelectFunc(mlir::ModuleOp mod, const std::string &selector)
1819{
1920 MLIRContext *context = mod.getContext ();
2021 OpBuilder builder (mod.getBodyRegion ());
22+ auto loc = builder.getUnknownLoc ();
2123 builder.setInsertionPointToStart (mod.getBody ());
2224
23- // Create function signature
25+ bool needs_col = false , needs_val = false ;
26+ if (selector == " triu" ) {
27+ needs_col = true ;
28+ needs_val = false ;
29+ } else if (selector == " tril" ) {
30+ needs_col = true ;
31+ needs_val = false ;
32+ } else if (selector == " gt0" ) {
33+ needs_col = false ;
34+ needs_val = true ;
35+ } else {
36+ assert (!" invalid selector" );
37+ }
38+
39+ // Types
2440 auto valueType = builder.getF64Type ();
41+ auto i64Type = builder.getI64Type ();
42+ auto f64Type = builder.getF64Type ();
43+ auto indexType = builder.getIndexType ();
44+ auto memref1DI64Type = MemRefType::get ({-1 }, i64Type);
45+ auto memref1DValueType = MemRefType::get ({-1 }, valueType);
46+
47+ // Create function signature
2548 RankedTensorType csrTensor = getCSRTensorType (context, valueType);
2649
2750 string func_name = " matrix_select_" + selector;
@@ -33,8 +56,97 @@ void addMatrixSelectFunc(mlir::ModuleOp mod, const std::string &selector)
3356 auto &entry_block = *func.addEntryBlock ();
3457 builder.setInsertionPointToStart (&entry_block);
3558
59+ auto input = entry_block.getArgument (0 );
60+
3661 // add function body ops here
62+ // Initial constants
63+ Value c0 = builder.create <ConstantIndexOp>(loc, 0 );
64+ Value c1 = builder.create <ConstantIndexOp>(loc, 1 );
65+ Value c0_64 = builder.create <ConstantIntOp>(loc, 0 , i64Type);
66+ Value c1_64 = builder.create <ConstantIntOp>(loc, 1 , i64Type);
67+ Value cf0 = builder.create <ConstantFloatOp>(loc, APFloat (0.0 ), f64Type);
68+
69+ // Get sparse tensor info
70+ Value nrow = builder.create <memref::DimOp>(loc, input, c0);
71+ Value ncol = builder.create <memref::DimOp>(loc, input, c1);
72+ Value Ap = builder.create <ToPointersOp>(loc, memref1DI64Type, input, c1);
73+ Value Aj = builder.create <ToIndicesOp>(loc, memref1DI64Type, input, c1);
74+ Value Ax = builder.create <ToValuesOp>(loc, memref1DValueType, input);
75+
76+ Value output = callDupTensor (builder, mod, loc, input).getResult (0 );
77+ Value Bp = builder.create <ToPointersOp>(loc, memref1DI64Type, output, c1);
78+ Value Bj = builder.create <ToIndicesOp>(loc, memref1DI64Type, output, c1);
79+ Value Bx = builder.create <ToValuesOp>(loc, memref1DValueType, output);
80+
81+ builder.create <memref::StoreOp>(loc, c0_64, Bp, c0);
82+ // Loop
83+ auto outerLoop = builder.create <scf::ForOp>(loc, c0, nrow, c1);
84+ Value row = outerLoop.getInductionVar ();
85+
86+ builder.setInsertionPointToStart (outerLoop.getBody ());
87+ Value row_plus1 = builder.create <mlir::AddIOp>(loc, row, c1);
88+ Value bp_curr_count = builder.create <memref::LoadOp>(loc, Bp, row);
89+ builder.create <memref::StoreOp>(loc, bp_curr_count, Bp, row_plus1);
90+
91+ Value j_start_64 = builder.create <memref::LoadOp>(loc, Ap, row);
92+ Value j_end_64 = builder.create <memref::LoadOp>(loc, Ap, row_plus1);
93+ Value j_start = builder.create <mlir::IndexCastOp>(loc, j_start_64, indexType);
94+ Value j_end = builder.create <mlir::IndexCastOp>(loc, j_end_64, indexType);
95+
96+ auto innerLoop = builder.create <scf::ForOp>(loc, j_start, j_end, c1);
97+
98+ Value jj = innerLoop.getInductionVar ();
99+
100+ builder.setInsertionPointToStart (innerLoop.getBody ());
101+ Value col_64, col, val, keep;
102+ if (needs_col) {
103+ col_64 = builder.create <memref::LoadOp>(loc, Aj, jj);
104+ col = builder.create <mlir::IndexCastOp>(loc, col_64, indexType);
105+ }
106+ if (needs_val) {
107+ val = builder.create <memref::LoadOp>(loc, Ax, jj);
108+ }
109+ if (selector == " triu" ) {
110+ keep = builder.create <mlir::CmpIOp>(loc, mlir::CmpIPredicate::ugt, col, row);
111+ }
112+ else if (selector == " tril" ) {
113+ keep = builder.create <mlir::CmpIOp>(loc, mlir::CmpIPredicate::ult, col, row);
114+ }
115+ else if (selector == " gt0" ) {
116+ keep = builder.create <mlir::CmpFOp>(loc, mlir::CmpFPredicate::OGT, val, cf0);
117+ }
118+ else {
119+ assert (!" invalid selector" );
120+ }
121+
122+ scf::IfOp ifKeep = builder.create <scf::IfOp>(loc, keep, false /* no else region */ );
123+
124+ builder.setInsertionPointToStart (ifKeep.thenBlock ());
125+
126+ Value bj_pos_64 = builder.create <memref::LoadOp>(loc, Bp, row_plus1);
127+ Value bj_pos = builder.create <mlir::IndexCastOp>(loc, bj_pos_64, indexType);
128+
129+ if (!needs_col) {
130+ col_64 = builder.create <memref::LoadOp>(loc, Aj, jj);
131+ }
132+ builder.create <memref::StoreOp>(loc, col_64, Bj, bj_pos);
133+
134+ if (!needs_val) {
135+ val = builder.create <memref::LoadOp>(loc, Ax, jj);
136+ }
137+ builder.create <memref::StoreOp>(loc, val, Bx, bj_pos);
138+
139+ Value bj_pos_plus1 = builder.create <mlir::AddIOp>(loc, bj_pos_64, c1_64);
140+ builder.create <memref::StoreOp>(loc, bj_pos_plus1, Bp, row_plus1);
141+
142+ builder.setInsertionPointAfter (outerLoop);
143+
144+ Value nnz_64 = builder.create <memref::LoadOp>(loc, Bp, nrow);
145+ Value nnz = builder.create <mlir::IndexCastOp>(loc, nnz_64, indexType);
146+
147+ callResizeIndex (builder, mod, loc, output, c1, nnz);
148+ callResizeValues (builder, mod, loc, output, nnz);
37149
38150 // Add return op
39- builder.create <ReturnOp>(builder.getUnknownLoc ());
151+ builder.create <ReturnOp>(builder.getUnknownLoc (), output );
40152}