@@ -185,6 +185,32 @@ def GraphBLAS_MatrixReduceToScalarOp : GraphBLAS_Op<"matrix_reduce_to_scalar", [
185185 }];
186186}
187187
188+ def GraphBLAS_MatrixReduceToScalarGenericOp : GraphBLAS_Op<"matrix_reduce_to_scalar_generic", [NoSideEffect]> {
189+ let summary = "matrix reduce to scalar generic operation";
190+ let description = [{
191+ Reduces a sparse tensor to a scalar according to the given aggregator block.
192+ The given sparse tensor must be a matrix, i.e. have rank 2.
193+ The given tensor must have a CSR sparsity or a CSC sparsity.
194+ The resulting scalar's type will depend on the type of the input tensor.
195+
196+ Example:
197+ ```%answer = graphblas.matrix_reduce_to_scalar_generic %sparse_tensor : tensor<?x?xi64, #CSR64> to i64 {
198+ ^bb0(%a : i64, %b : i64):
199+ %result = std.addi %a, %b : i64
200+ graphblas.yield agg %result : i64
201+ }
202+ ```
203+ }];
204+
205+ let arguments = (ins AnyTensor:$input);
206+ let results = (outs AnyType:$output);
207+ let regions = (region VariadicRegion<SizedRegion<1>>:$extensions);
208+
209+ let assemblyFormat = [{
210+ $input attr-dict `:` type($input) `to` type($output) $extensions
211+ }];
212+ }
213+
188214def GraphBLAS_MatrixApplyOp : GraphBLAS_Op<"matrix_apply", [NoSideEffect]> {
189215 let summary = "matrix apply operation";
190216 let description = [{
@@ -284,31 +310,24 @@ def GraphBLAS_MatrixMultiplyGenericOp : GraphBLAS_Op<"matrix_multiply_generic",
284310 }];
285311}
286312
287- def GraphBLAS_MatrixMultiplyReduceToScalarOp : GraphBLAS_Op<"matrix_multiply_reduce_to_scalar ", [NoSideEffect]> {
313+ def GraphBLAS_MatrixMultiplyReduceToScalarGenericOp : GraphBLAS_Op<"matrix_multiply_reduce_to_scalar_generic ", [NoSideEffect]> {
288314 let summary = "matrix multiply followed by reduction to a scalar with an optional structural mask";
289315 let description = [{
290316 Performs a matrix multiply followed by a reduction to scalar.
291- The multiplication is done according to the given semiring and optional structural mask.
292- The semiring must be one of "plus_times", "plus_pair", or "plus_plus".
293- The reduction to scalar is done according to the given aggregator.
294- The aggregator must be "sum".
317+ Supports same extension blocks as matrix_multiply_generic, and also requires binary aggregation
318+ block (aggregation assumes same identity as semiring add).
319+
295320 The given sparse tensors must be a matrix, i.e. have rank 2.
296321 The first input tensors must be CSR format, while the second input tensor must be CSC format.
297322 The mask (if provided) must be CSR format.
298-
299- No Mask Example:
300- ```%answer = graphblas.matrix_multiply_reduce_to_scalar %argA, %argB { semiring = "plus_plus", aggregator = "sum" } : (tensor<?x?xi64, #CSR64>, tensor<?x?xi64, #CSC64>) to f64```
301-
302- Mask Example:
303- ```%answer = graphblas.matrix_multiply_reduce_to_scalar %argA, %argB, %mask { semiring = "plus_times", aggregator = "sum" } : (tensor<?x?xf64, #CSR64>, tensor<?x?xf64, #CSC64>, tensor<?x?xf64, #CSR64>) to f64```
304-
305323 }];
306324
307- let arguments = (ins AnyTensor:$a, AnyTensor:$b, Optional<AnyTensor>:$mask, StrAttr:$semiring, StrAttr:$aggregator );
325+ let arguments = (ins AnyTensor:$a, AnyTensor:$b, Optional<AnyTensor>:$mask);
308326 let results = (outs AnyType:$output);
327+ let regions = (region VariadicRegion<SizedRegion<1>>:$extensions);
309328
310329 let assemblyFormat = [{
311- $a `,` $b (`,` $mask^)? attr-dict `:` `(` type($a) `,` type($b) (`,` type($mask)^)? `)` `to` type($output)
330+ $a `,` $b (`,` $mask^)? attr-dict `:` `(` type($a) `,` type($b) (`,` type($mask)^)? `)` `to` type($output) $extensions
312331 }];
313332}
314333
@@ -367,13 +386,16 @@ def YIELD_ADD_IDENTITY : I64EnumAttrCase<"ADD_IDENTITY", 6, "add_identity">;
367386def YIELD_ADD : I64EnumAttrCase<"ADD", 7, "add">;
368387def YIELD_MULT_IDENTITY : I64EnumAttrCase<"MULT_IDENTITY", 8, "mult_identity">;
369388def YIELD_MULT : I64EnumAttrCase<"MULT", 9, "mult">;
389+ def YIELD_AGG_IDENTITY : I64EnumAttrCase<"AGG_IDENTITY", 10, "agg_identity">;
390+ def YIELD_AGG : I64EnumAttrCase<"AGG", 11, "agg">;
370391
371392def YieldKindAttr : I64EnumAttr<
372393 "YieldKind", "",
373394 [YIELD_TRANSFORM_IN_A, YIELD_TRANSFORM_IN_B, YIELD_TRANSFORM_OUT,
374395 YIELD_SELECT_IN_A, YIELD_SELECT_IN_B, YIELD_SELECT_OUT,
375396 YIELD_ADD_IDENTITY, YIELD_ADD,
376- YIELD_MULT_IDENTITY, YIELD_MULT]
397+ YIELD_MULT_IDENTITY, YIELD_MULT,
398+ YIELD_AGG_IDENTITY, YIELD_AGG]
377399 > {
378400 let cppNamespace = "::mlir::graphblas";
379401}
0 commit comments