diff --git a/src/enzyme_ad/jax/Utils.cpp b/src/enzyme_ad/jax/Utils.cpp index d3b208731c..79a22d4b02 100644 --- a/src/enzyme_ad/jax/Utils.cpp +++ b/src/enzyme_ad/jax/Utils.cpp @@ -606,6 +606,7 @@ SymmetricResultAnalysis::State SymmetricResultAnalysis::localGuaranteed( assert(op); auto outTy = cast(op->getResult(0).getType()); + if (outTy.getRank() != 2) return State::NOTGUARANTEED; // this pass only checks for symmetric matrices if (outTy.getDimSize(0) != outTy.getDimSize(1)) @@ -667,6 +668,27 @@ SymmetricResultAnalysis::State SymmetricResultAnalysis::localGuaranteed( recursiveCheck = true; } + // propagate symmetry for transpose + if (isa(op)) { + recursiveCheck = true; + } + + // propagate symmetry for A * A + if (auto dotGeneralOp = dyn_cast(op)) { + auto lhs = dotGeneralOp.getOperand(0); + auto rhs = dotGeneralOp.getOperand(1); + auto dimensionNumbers = dotGeneralOp.getDotDimensionNumbers(); + + if (lhs == rhs) { + auto lhs_contracting = dimensionNumbers.getLhsContractingDimensions(); + auto rhs_contracting = dimensionNumbers.getRhsContractingDimensions(); + + if (lhs_contracting.size() == 1 && rhs_contracting.size() == 1) { + recursiveCheck = true; + } + } + } + /** * TODO * - check if its * 0 -> symmetric @@ -739,13 +761,6 @@ NoNanResultAnalysis::localGuaranteed(Operation *op, PatternRewriter &rewriter) { assert(op); - if (auto boolAttr = op->getAttrOfType(getAttrName())) { - if (boolAttr.getValue()) - return State::GUARANTEED; - else - return State::NOTGUARANTEED; - } - DenseElementsAttr denseAttr; if (matchPattern(op, m_Constant(&denseAttr))) { if (guaranteedConstantOp(op, denseAttr, rewriter)) { @@ -882,13 +897,6 @@ FiniteResultAnalysis::localGuaranteed(Operation *op, PatternRewriter &rewriter) { assert(op); - if (auto boolAttr = op->getAttrOfType(getAttrName())) { - if (boolAttr.getValue()) - return State::GUARANTEED; - else - return State::NOTGUARANTEED; - } - DenseElementsAttr denseAttr; if (matchPattern(op, m_Constant(&denseAttr))) { if (guaranteedConstantOp(op, denseAttr, rewriter)) { @@ -995,13 +1003,6 @@ NonNegativeResultAnalysis::State NonNegativeResultAnalysis::localGuaranteed( PatternRewriter &rewriter) { assert(op); - if (auto boolAttr = op->getAttrOfType(getAttrName())) { - if (boolAttr.getValue()) - return State::GUARANTEED; - else - return State::NOTGUARANTEED; - } - DenseElementsAttr denseAttr; if (matchPattern(op, m_Constant(&denseAttr))) { if (guaranteedConstantOp(op, denseAttr, rewriter)) { diff --git a/src/enzyme_ad/jax/Utils.h b/src/enzyme_ad/jax/Utils.h index 72a52e9488..0c9b2d3864 100644 --- a/src/enzyme_ad/jax/Utils.h +++ b/src/enzyme_ad/jax/Utils.h @@ -541,8 +541,18 @@ template class GuaranteedResultAnalysisBase { State localGuaranteedWithSetAttr(Operation *op, SmallVectorImpl &localtodo, PatternRewriter &rewriter) { - auto state = ((Child *)this)->localGuaranteed(op, localtodo, rewriter); + auto attrName = ((Child *)this)->getAttrName(); + + if (auto boolAttr = op->getAttrOfType(attrName)) { + if (boolAttr.getValue()) + return State::GUARANTEED; + else + return State::NOTGUARANTEED; + } + + auto state = ((Child *)this)->localGuaranteed(op, localtodo, rewriter); + switch (state) { case State::GUARANTEED: rewriter.modifyOpInPlace(op, [&]() { diff --git a/test/lit_tests/structured_tensors/propagate_symmetric.mlir b/test/lit_tests/structured_tensors/propagate_symmetric.mlir new file mode 100644 index 0000000000..87efbd1081 --- /dev/null +++ b/test/lit_tests/structured_tensors/propagate_symmetric.mlir @@ -0,0 +1,16 @@ +// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=transpose_symmetric_simplify" --transform-interpreter --enzyme-hlo-remove-transform %s | FileCheck %s + +func.func @pass1(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>, %arg2: tensor<2x2xf32>) -> tensor<2x2xf32> { + %0 = stablehlo.reshape %arg0 {enzymexla.guaranteed_symmetric = true} : (tensor<2x2xf32>) -> tensor<2x2xf32> + %1 = stablehlo.dot_general %0, %0, contracting_dims = [1] x [0] : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + %2 = stablehlo.subtract %1, %0 : tensor<2x2xf32> + %3 = stablehlo.transpose %2, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %3 : tensor<2x2xf32> +} + +// CHECK: func.func @pass1(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>, %arg2: tensor<2x2xf32>) -> tensor<2x2xf32> { +// CHECK-NEXT: %0 = stablehlo.reshape %arg0 {enzymexla.guaranteed_symmetric = true} : (tensor<2x2xf32>) -> tensor<2x2xf32> +// CHECK-NEXT: %1 = stablehlo.dot_general %0, %0, contracting_dims = [1] x [0] : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> +// CHECK-NEXT: %2 = stablehlo.subtract %1, %0 {enzymexla.guaranteed_symmetric = true} : tensor<2x2xf32> +// CHECK-NEXT: return %2 : tensor<2x2xf32> +// CHECK-NEXT: }