Skip to content
Open
43 changes: 22 additions & 21 deletions src/enzyme_ad/jax/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,7 @@ SymmetricResultAnalysis::State SymmetricResultAnalysis::localGuaranteed(
assert(op);

auto outTy = cast<RankedTensorType>(op->getResult(0).getType());

if (outTy.getRank() != 2)
return State::NOTGUARANTEED; // this pass only checks for symmetric matrices
if (outTy.getDimSize(0) != outTy.getDimSize(1))
Expand Down Expand Up @@ -667,6 +668,27 @@ SymmetricResultAnalysis::State SymmetricResultAnalysis::localGuaranteed(
recursiveCheck = true;
}

// propagate symmetry for transpose
if (isa<stablehlo::TransposeOp>(op)) {
recursiveCheck = true;
}

// propagate symmetry for A * A
if (auto dotGeneralOp = dyn_cast<stablehlo::DotGeneralOp>(op)) {
auto lhs = dotGeneralOp.getOperand(0);
auto rhs = dotGeneralOp.getOperand(1);
auto dimensionNumbers = dotGeneralOp.getDotDimensionNumbers();

if (lhs == rhs) {
auto lhs_contracting = dimensionNumbers.getLhsContractingDimensions();
auto rhs_contracting = dimensionNumbers.getRhsContractingDimensions();

if (lhs_contracting.size() == 1 && rhs_contracting.size() == 1) {
recursiveCheck = true;
}
}
}

/**
* TODO
* - check if its * 0 -> symmetric
Expand Down Expand Up @@ -739,13 +761,6 @@ NoNanResultAnalysis::localGuaranteed(Operation *op,
PatternRewriter &rewriter) {
assert(op);

if (auto boolAttr = op->getAttrOfType<BoolAttr>(getAttrName())) {
if (boolAttr.getValue())
return State::GUARANTEED;
else
return State::NOTGUARANTEED;
}

DenseElementsAttr denseAttr;
if (matchPattern(op, m_Constant(&denseAttr))) {
if (guaranteedConstantOp(op, denseAttr, rewriter)) {
Expand Down Expand Up @@ -882,13 +897,6 @@ FiniteResultAnalysis::localGuaranteed(Operation *op,
PatternRewriter &rewriter) {
assert(op);

if (auto boolAttr = op->getAttrOfType<BoolAttr>(getAttrName())) {
if (boolAttr.getValue())
return State::GUARANTEED;
else
return State::NOTGUARANTEED;
}

DenseElementsAttr denseAttr;
if (matchPattern(op, m_Constant(&denseAttr))) {
if (guaranteedConstantOp(op, denseAttr, rewriter)) {
Expand Down Expand Up @@ -995,13 +1003,6 @@ NonNegativeResultAnalysis::State NonNegativeResultAnalysis::localGuaranteed(
PatternRewriter &rewriter) {
assert(op);

if (auto boolAttr = op->getAttrOfType<BoolAttr>(getAttrName())) {
if (boolAttr.getValue())
return State::GUARANTEED;
else
return State::NOTGUARANTEED;
}

DenseElementsAttr denseAttr;
if (matchPattern(op, m_Constant(&denseAttr))) {
if (guaranteedConstantOp(op, denseAttr, rewriter)) {
Expand Down
12 changes: 11 additions & 1 deletion src/enzyme_ad/jax/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -541,8 +541,18 @@ template <typename Child> class GuaranteedResultAnalysisBase {
State localGuaranteedWithSetAttr(Operation *op,
SmallVectorImpl<Operation *> &localtodo,
PatternRewriter &rewriter) {
auto state = ((Child *)this)->localGuaranteed(op, localtodo, rewriter);

auto attrName = ((Child *)this)->getAttrName();

if (auto boolAttr = op->getAttrOfType<BoolAttr>(attrName)) {
if (boolAttr.getValue())
return State::GUARANTEED;
else
return State::NOTGUARANTEED;
}

auto state = ((Child *)this)->localGuaranteed(op, localtodo, rewriter);

switch (state) {
case State::GUARANTEED:
rewriter.modifyOpInPlace(op, [&]() {
Expand Down
16 changes: 16 additions & 0 deletions test/lit_tests/structured_tensors/propagate_symmetric.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=transpose_symmetric_simplify" --transform-interpreter --enzyme-hlo-remove-transform %s | FileCheck %s

func.func @pass1(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>, %arg2: tensor<2x2xf32>) -> tensor<2x2xf32> {
%0 = stablehlo.reshape %arg0 {enzymexla.guaranteed_symmetric = true} : (tensor<2x2xf32>) -> tensor<2x2xf32>
%1 = stablehlo.dot_general %0, %0, contracting_dims = [1] x [0] : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
%2 = stablehlo.subtract %1, %0 : tensor<2x2xf32>
%3 = stablehlo.transpose %2, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %3 : tensor<2x2xf32>
}

// CHECK: func.func @pass1(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>, %arg2: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK-NEXT: %0 = stablehlo.reshape %arg0 {enzymexla.guaranteed_symmetric = true} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: %1 = stablehlo.dot_general %0, %0, contracting_dims = [1] x [0] : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: %2 = stablehlo.subtract %1, %0 {enzymexla.guaranteed_symmetric = true} : tensor<2x2xf32>
// CHECK-NEXT: return %2 : tensor<2x2xf32>
// CHECK-NEXT: }
Loading