Skip to content

Commit 5879928

Browse files
committed
incorporate code review feedback
1 parent 3463a36 commit 5879928

File tree

4 files changed

+24
-13
lines changed

4 files changed

+24
-13
lines changed

mlir_graphblas/src/include/GraphBLAS/GraphBLASUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ bool typeIsCSC(mlir::Type inputType);
1414
mlir::RankedTensorType getCSRTensorType(mlir::MLIRContext *context, llvm::ArrayRef<int64_t> shape, mlir::Type valueType);
1515
mlir::RankedTensorType getCSCTensorType(mlir::MLIRContext *context, llvm::ArrayRef<int64_t> shape, mlir::Type valueType);
1616

17+
int64_t getRank(mlir::Type inputType);
18+
int64_t getRank(mlir::Value inputValue);
19+
1720
mlir::Value convertToExternalCSR(mlir::OpBuilder &builder, mlir::ModuleOp &mod, mlir::Location loc, mlir::Value input);
1821
mlir::Value convertToExternalCSC(mlir::OpBuilder &builder, mlir::ModuleOp &mod, mlir::Location loc, mlir::Value input);
1922
mlir::Value callEmptyLike(mlir::OpBuilder &builder, mlir::ModuleOp &mod, mlir::Location loc, mlir::Value tensor);

mlir_graphblas/src/lib/GraphBLAS/GraphBLASOps.cpp

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "llvm/ADT/None.h"
1414

1515
#include "GraphBLAS/GraphBLASOpsEnums.cpp.inc"
16+
#include "GraphBLAS/GraphBLASUtils.h"
1617

1718
using namespace mlir;
1819
using namespace mlir::graphblas;
@@ -99,17 +100,6 @@ static llvm::Optional<std::string> checkCompressedVector(
99100
return llvm::None;
100101
}
101102

102-
static int64_t getRank(Type inputType)
103-
{
104-
mlir::sparse_tensor::SparseTensorEncodingAttr sparseEncoding =
105-
mlir::sparse_tensor::getSparseTensorEncoding(inputType);
106-
if (!sparseEncoding)
107-
return -1;
108-
109-
RankedTensorType inputTensorType = inputType.dyn_cast<RankedTensorType>();
110-
return inputTensorType.getRank();
111-
}
112-
113103
//===--------------------------------------------------------------------===//
114104
// GraphBLAS Ops Methods
115105
//===--------------------------------------------------------------------===//
@@ -310,7 +300,6 @@ static LogicalResult verifyMatrixMultiplyArgs(T op, bool checkResultTensorType)
310300
resultTensorType = resultType.dyn_cast<RankedTensorType>();
311301
resultShape = resultTensorType.getShape();
312302
resultRank = getRank(resultType);
313-
RankedTensorType resultTensorType = resultType.dyn_cast<RankedTensorType>();
314303
resultElementType = resultTensorType.getElementType();
315304
}
316305
}

mlir_graphblas/src/lib/GraphBLAS/GraphBLASOptimizePass.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ class FuseMatrixMultiplyReduceRewrite : public OpRewritePattern<graphblas::Matri
9797
if (predecessor != nullptr && predecessor->hasOneUse()) {
9898
Location loc = op->getLoc();
9999

100+
if (getRank(predecessor.a()) < 2 || getRank(predecessor.b()) < 2)
101+
return failure();
102+
100103
// Build new MatrixMultiplyReduceToScalarGeneric op with the operands and regions of the multiply,
101104
// then add in the aggregator from the reduce
102105
ValueRange operands = predecessor.getOperands();
@@ -190,7 +193,6 @@ class FuseMatrixMultiplyApplyRewrite : public OpRewritePattern<graphblas::Matrix
190193
RegionRange applyExtensions = op.extensions();
191194

192195
RankedTensorType tensorType = predecessor.a().getType().dyn_cast<RankedTensorType>();
193-
Type valueType = tensorType.getElementType();
194196

195197
graphblas::MatrixMultiplyGenericOp newMultOp = rewriter.create<graphblas::MatrixMultiplyGenericOp>(loc,
196198
op->getResultTypes(), operands, attributes.getAttrs(),

mlir_graphblas/src/lib/GraphBLAS/GraphBLASUtils.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,23 @@ bool typeIsCSC(Type inputType) {
9292
return true;
9393
}
9494

95+
int64_t getRank(Type inputType)
96+
{
97+
mlir::sparse_tensor::SparseTensorEncodingAttr sparseEncoding =
98+
mlir::sparse_tensor::getSparseTensorEncoding(inputType);
99+
if (!sparseEncoding)
100+
return -1;
101+
102+
RankedTensorType inputTensorType = inputType.dyn_cast<RankedTensorType>();
103+
return inputTensorType.getRank();
104+
}
105+
106+
int64_t getRank(Value inputValue)
107+
{
108+
Type inputType = inputValue.getType();
109+
return getRank(inputType);
110+
}
111+
95112
// make Compressed Vector type
96113
RankedTensorType getCompressedVectorType(MLIRContext *context, ArrayRef<int64_t> shape, Type valueType)
97114
{

0 commit comments

Comments
 (0)