@@ -182,8 +182,139 @@ class LowerMatrixSelectRewrite : public OpRewritePattern<graphblas::MatrixSelect
182182public:
183183 using OpRewritePattern<graphblas::MatrixSelectOp>::OpRewritePattern;
184184 LogicalResult matchAndRewrite (graphblas::MatrixSelectOp op, PatternRewriter &rewriter) const {
185- // TODO sanity check that the sparse encoding is sane
186- return failure ();
185+ ModuleOp module = op->getParentOfType <ModuleOp>();
186+ Location loc = op->getLoc ();
187+
188+ Value input = op.input ();
189+ Type valueType = input.getType ().dyn_cast <TensorType>().getElementType ();
190+ Type int64Type = rewriter.getIntegerType (64 );
191+ FloatType float64Type = rewriter.getF64Type ();
192+ Type indexType = rewriter.getIndexType ();
193+ Type memref1DI64Type = MemRefType::get ({-1 }, int64Type);
194+ Type memref1DValueType = MemRefType::get ({-1 }, valueType);
195+
196+ StringRef selector = op.selector ();
197+
198+ bool needs_col = false , needs_val = false ;
199+ if (selector == " triu" )
200+ {
201+ needs_col = true ;
202+ needs_val = false ;
203+ }
204+ else if (selector == " tril" )
205+ {
206+ needs_col = true ;
207+ needs_val = false ;
208+ }
209+ else if (selector == " gt0" )
210+ {
211+ needs_col = false ;
212+ needs_val = true ;
213+ }
214+ else
215+ {
216+ return failure ();
217+ }
218+
219+ // Initial constants
220+ Value c0 = rewriter.create <ConstantIndexOp>(loc, 0 );
221+ Value c1 = rewriter.create <ConstantIndexOp>(loc, 1 );
222+ Value c0_64 = rewriter.create <ConstantIntOp>(loc, 0 , int64Type);
223+ Value c1_64 = rewriter.create <ConstantIntOp>(loc, 1 , int64Type);
224+ Value cf0 = rewriter.create <ConstantFloatOp>(loc, APFloat (0.0 ), float64Type);
225+
226+ // Get sparse tensor info
227+ Value nrow = rewriter.create <memref::DimOp>(loc, input, c0);
228+ Value Ap = rewriter.create <sparse_tensor::ToPointersOp>(loc, memref1DI64Type, input, c1);
229+ Value Aj = rewriter.create <sparse_tensor::ToIndicesOp>(loc, memref1DI64Type, input, c1);
230+ Value Ax = rewriter.create <sparse_tensor::ToValuesOp>(loc, memref1DValueType, input);
231+
232+ Value output = callDupTensor (rewriter, module , loc, input).getResult (0 );
233+ Value Bp = rewriter.create <sparse_tensor::ToPointersOp>(loc, memref1DI64Type, output, c1);
234+ Value Bj = rewriter.create <sparse_tensor::ToIndicesOp>(loc, memref1DI64Type, output, c1);
235+ Value Bx = rewriter.create <sparse_tensor::ToValuesOp>(loc, memref1DValueType, output);
236+
237+ rewriter.create <memref::StoreOp>(loc, c0_64, Bp, c0);
238+ // Loop
239+ scf::ForOp outerLoop = rewriter.create <scf::ForOp>(loc, c0, nrow, c1);
240+ Value row = outerLoop.getInductionVar ();
241+
242+ rewriter.setInsertionPointToStart (outerLoop.getBody ());
243+ Value row_plus1 = rewriter.create <mlir::AddIOp>(loc, row, c1);
244+ Value bp_curr_count = rewriter.create <memref::LoadOp>(loc, Bp, row);
245+ rewriter.create <memref::StoreOp>(loc, bp_curr_count, Bp, row_plus1);
246+
247+ Value j_start_64 = rewriter.create <memref::LoadOp>(loc, Ap, row);
248+ Value j_end_64 = rewriter.create <memref::LoadOp>(loc, Ap, row_plus1);
249+ Value j_start = rewriter.create <mlir::IndexCastOp>(loc, j_start_64, indexType);
250+ Value j_end = rewriter.create <mlir::IndexCastOp>(loc, j_end_64, indexType);
251+
252+ scf::ForOp innerLoop = rewriter.create <scf::ForOp>(loc, j_start, j_end, c1);
253+
254+ Value jj = innerLoop.getInductionVar ();
255+
256+ rewriter.setInsertionPointToStart (innerLoop.getBody ());
257+ Value col_64, col, val, keep;
258+ if (needs_col)
259+ {
260+ col_64 = rewriter.create <memref::LoadOp>(loc, Aj, jj);
261+ col = rewriter.create <mlir::IndexCastOp>(loc, col_64, indexType);
262+ }
263+ if (needs_val)
264+ {
265+ val = rewriter.create <memref::LoadOp>(loc, Ax, jj);
266+ }
267+ if (selector == " triu" )
268+ {
269+ keep = rewriter.create <mlir::CmpIOp>(loc, mlir::CmpIPredicate::ugt, col, row);
270+ }
271+ else if (selector == " tril" )
272+ {
273+ keep = rewriter.create <mlir::CmpIOp>(loc, mlir::CmpIPredicate::ult, col, row);
274+ }
275+ else if (selector == " gt0" )
276+ {
277+ keep = rewriter.create <mlir::CmpFOp>(loc, mlir::CmpFPredicate::OGT, val, cf0);
278+ }
279+ else
280+ {
281+ return failure ();
282+ }
283+
284+ scf::IfOp ifKeep = rewriter.create <scf::IfOp>(loc, keep, false /* no else region */ );
285+
286+ rewriter.setInsertionPointToStart (ifKeep.thenBlock ());
287+
288+ Value bj_pos_64 = rewriter.create <memref::LoadOp>(loc, Bp, row_plus1);
289+ Value bj_pos = rewriter.create <mlir::IndexCastOp>(loc, bj_pos_64, indexType);
290+
291+ if (!needs_col)
292+ {
293+ col_64 = rewriter.create <memref::LoadOp>(loc, Aj, jj);
294+ }
295+ rewriter.create <memref::StoreOp>(loc, col_64, Bj, bj_pos);
296+
297+ if (!needs_val)
298+ {
299+ val = rewriter.create <memref::LoadOp>(loc, Ax, jj);
300+ }
301+ rewriter.create <memref::StoreOp>(loc, val, Bx, bj_pos);
302+
303+ Value bj_pos_plus1 = rewriter.create <mlir::AddIOp>(loc, bj_pos_64, c1_64);
304+ rewriter.create <memref::StoreOp>(loc, bj_pos_plus1, Bp, row_plus1);
305+
306+ rewriter.setInsertionPointAfter (outerLoop);
307+
308+ // trim excess values
309+ Value nnz_64 = rewriter.create <memref::LoadOp>(loc, Bp, nrow);
310+ Value nnz = rewriter.create <mlir::IndexCastOp>(loc, nnz_64, indexType);
311+
312+ callResizeIndex (rewriter, module , loc, output, c1, nnz);
313+ callResizeValues (rewriter, module , loc, output, nnz);
314+
315+ rewriter.replaceOp (op, output);
316+
317+ return success ();
187318 };
188319};
189320
@@ -343,6 +474,7 @@ class LowerMatrixMultiplyRewrite : public OpRewritePattern<graphblas::MatrixMult
343474
344475void populateGraphBLASLoweringPatterns (RewritePatternSet &patterns) {
345476 patterns.add <
477+ LowerMatrixSelectRewrite,
346478 LowerMatrixReduceToScalarRewrite,
347479 LowerMatrixMultiplyRewrite,
348480 LowerTransposeRewrite
0 commit comments