Skip to content

Commit 22d7846

Browse files
authored
Merge pull request #101 from seibert/matrix_reduce_extensions
convert matrix_reduce_to_scalar and matrix_multiply_reduce_to_scalar to structural form
2 parents 513a74a + 5879928 commit 22d7846

11 files changed

+548
-234
lines changed

mlir_graphblas/src/include/GraphBLAS/GraphBLASOps.td

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,32 @@ def GraphBLAS_MatrixReduceToScalarOp : GraphBLAS_Op<"matrix_reduce_to_scalar", [
185185
}];
186186
}
187187

188+
def GraphBLAS_MatrixReduceToScalarGenericOp : GraphBLAS_Op<"matrix_reduce_to_scalar_generic", [NoSideEffect]> {
189+
let summary = "matrix reduce to scalar generic operation";
190+
let description = [{
191+
Reduces a sparse tensor to a scalar according to the given aggregator block.
192+
The given sparse tensor must be a matrix, i.e. have rank 2.
193+
The given tensor must have a CSR sparsity or a CSC sparsity.
194+
The resulting scalar's type will depend on the type of the input tensor.
195+
196+
Example:
197+
```%answer = graphblas.matrix_reduce_to_scalar_generic %sparse_tensor : tensor<?x?xi64, #CSR64> to i64 {
198+
^bb0(%a : i64, %b : i64):
199+
%result = std.addi %a, %b : i64
200+
graphblas.yield agg %result : i64
201+
}
202+
```
203+
}];
204+
205+
let arguments = (ins AnyTensor:$input);
206+
let results = (outs AnyType:$output);
207+
let regions = (region VariadicRegion<SizedRegion<1>>:$extensions);
208+
209+
let assemblyFormat = [{
210+
$input attr-dict `:` type($input) `to` type($output) $extensions
211+
}];
212+
}
213+
188214
def GraphBLAS_MatrixApplyOp : GraphBLAS_Op<"matrix_apply", [NoSideEffect]> {
189215
let summary = "matrix apply operation";
190216
let description = [{
@@ -284,31 +310,24 @@ def GraphBLAS_MatrixMultiplyGenericOp : GraphBLAS_Op<"matrix_multiply_generic",
284310
}];
285311
}
286312

287-
def GraphBLAS_MatrixMultiplyReduceToScalarOp : GraphBLAS_Op<"matrix_multiply_reduce_to_scalar", [NoSideEffect]> {
313+
def GraphBLAS_MatrixMultiplyReduceToScalarGenericOp : GraphBLAS_Op<"matrix_multiply_reduce_to_scalar_generic", [NoSideEffect]> {
288314
let summary = "matrix multiply followed by reduction to a scalar with an optional structural mask";
289315
let description = [{
290316
Performs a matrix multiply followed by a reduction to scalar.
291-
The multiplication is done according to the given semiring and optional structural mask.
292-
The semiring must be one of "plus_times", "plus_pair", or "plus_plus".
293-
The reduction to scalar is done according to the given aggregator.
294-
The aggregator must be "sum".
317+
Supports same extension blocks as matrix_multiply_generic, and also requires binary aggregation
318+
block (aggregation assumes same identity as semiring add).
319+
295320
The given sparse tensors must be a matrix, i.e. have rank 2.
296321
The first input tensors must be CSR format, while the second input tensor must be CSC format.
297322
The mask (if provided) must be CSR format.
298-
299-
No Mask Example:
300-
```%answer = graphblas.matrix_multiply_reduce_to_scalar %argA, %argB { semiring = "plus_plus", aggregator = "sum" } : (tensor<?x?xi64, #CSR64>, tensor<?x?xi64, #CSC64>) to f64```
301-
302-
Mask Example:
303-
```%answer = graphblas.matrix_multiply_reduce_to_scalar %argA, %argB, %mask { semiring = "plus_times", aggregator = "sum" } : (tensor<?x?xf64, #CSR64>, tensor<?x?xf64, #CSC64>, tensor<?x?xf64, #CSR64>) to f64```
304-
305323
}];
306324

307-
let arguments = (ins AnyTensor:$a, AnyTensor:$b, Optional<AnyTensor>:$mask, StrAttr:$semiring, StrAttr:$aggregator);
325+
let arguments = (ins AnyTensor:$a, AnyTensor:$b, Optional<AnyTensor>:$mask);
308326
let results = (outs AnyType:$output);
327+
let regions = (region VariadicRegion<SizedRegion<1>>:$extensions);
309328

310329
let assemblyFormat = [{
311-
$a `,` $b (`,` $mask^)? attr-dict `:` `(` type($a) `,` type($b) (`,` type($mask)^)? `)` `to` type($output)
330+
$a `,` $b (`,` $mask^)? attr-dict `:` `(` type($a) `,` type($b) (`,` type($mask)^)? `)` `to` type($output) $extensions
312331
}];
313332
}
314333

@@ -367,13 +386,16 @@ def YIELD_ADD_IDENTITY : I64EnumAttrCase<"ADD_IDENTITY", 6, "add_identity">;
367386
def YIELD_ADD : I64EnumAttrCase<"ADD", 7, "add">;
368387
def YIELD_MULT_IDENTITY : I64EnumAttrCase<"MULT_IDENTITY", 8, "mult_identity">;
369388
def YIELD_MULT : I64EnumAttrCase<"MULT", 9, "mult">;
389+
def YIELD_AGG_IDENTITY : I64EnumAttrCase<"AGG_IDENTITY", 10, "agg_identity">;
390+
def YIELD_AGG : I64EnumAttrCase<"AGG", 11, "agg">;
370391

371392
def YieldKindAttr : I64EnumAttr<
372393
"YieldKind", "",
373394
[YIELD_TRANSFORM_IN_A, YIELD_TRANSFORM_IN_B, YIELD_TRANSFORM_OUT,
374395
YIELD_SELECT_IN_A, YIELD_SELECT_IN_B, YIELD_SELECT_OUT,
375396
YIELD_ADD_IDENTITY, YIELD_ADD,
376-
YIELD_MULT_IDENTITY, YIELD_MULT]
397+
YIELD_MULT_IDENTITY, YIELD_MULT,
398+
YIELD_AGG_IDENTITY, YIELD_AGG]
377399
> {
378400
let cppNamespace = "::mlir::graphblas";
379401
}

mlir_graphblas/src/include/GraphBLAS/GraphBLASUtils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ bool typeIsCSC(mlir::Type inputType);
1414
mlir::RankedTensorType getCSRTensorType(mlir::MLIRContext *context, llvm::ArrayRef<int64_t> shape, mlir::Type valueType);
1515
mlir::RankedTensorType getCSCTensorType(mlir::MLIRContext *context, llvm::ArrayRef<int64_t> shape, mlir::Type valueType);
1616

17+
int64_t getRank(mlir::Type inputType);
18+
int64_t getRank(mlir::Value inputValue);
19+
1720
mlir::Value convertToExternalCSR(mlir::OpBuilder &builder, mlir::ModuleOp &mod, mlir::Location loc, mlir::Value input);
1821
mlir::Value convertToExternalCSC(mlir::OpBuilder &builder, mlir::ModuleOp &mod, mlir::Location loc, mlir::Value input);
1922
mlir::Value callEmptyLike(mlir::OpBuilder &builder, mlir::ModuleOp &mod, mlir::Location loc, mlir::Value tensor);
@@ -42,6 +45,8 @@ struct ExtensionBlocks {
4245
mlir::Block *add = nullptr;
4346
mlir::Block *multIdentity = nullptr;
4447
mlir::Block *mult = nullptr;
48+
mlir::Block *aggIdentity = nullptr;
49+
mlir::Block *agg = nullptr;
4550

4651
ExtensionBlocks() { };
4752
mlir::LogicalResult extractBlocks(mlir::Operation *op, mlir::RegionRange &regions,

0 commit comments

Comments
 (0)