Skip to content

Commit 3da6a56

Browse files
committed
feat: cleaner printing
1 parent 4eb0db6 commit 3da6a56

File tree

2 files changed

+44
-3
lines changed

2 files changed

+44
-3
lines changed

src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class StructuredSparsityPattern {
6767
refineKind();
6868
}
6969

70+
StructuredSparsityKind getKind() const { return kind; }
7071
int64_t getLowerBandwidth() const { return lowerBandwidth; }
7172
int64_t getUpperBandwidth() const { return upperBandwidth; }
7273

src/enzyme_ad/jax/Passes/StructuredMatrixSimplify.cpp

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,31 +70,71 @@ class StructuredMatrixSimplifyPass
7070
continue;
7171
}
7272

73-
anyKnown = true;
73+
if (state->getValue().getSparsityPattern().getKind() !=
74+
mlir::structure_analysis::StructuredSparsityKind::Unknown) {
75+
anyKnown = true;
76+
}
77+
78+
enzymexla::StructuredSparsityKind ssKind;
79+
switch (state->getValue().getSparsityPattern().getKind()) {
80+
case mlir::structure_analysis::StructuredSparsityKind::Unknown:
81+
ssKind = enzymexla::StructuredSparsityKind::Unknown;
82+
break;
83+
case mlir::structure_analysis::StructuredSparsityKind::Dense:
84+
ssKind = enzymexla::StructuredSparsityKind::Dense;
85+
break;
86+
case mlir::structure_analysis::StructuredSparsityKind::Band:
87+
ssKind = enzymexla::StructuredSparsityKind::Band;
88+
break;
89+
case mlir::structure_analysis::StructuredSparsityKind::UpperTriangular:
90+
ssKind = enzymexla::StructuredSparsityKind::UpperTriangular;
91+
break;
92+
case mlir::structure_analysis::StructuredSparsityKind::UpperBidiagonal:
93+
ssKind = enzymexla::StructuredSparsityKind::UpperBidiagonal;
94+
break;
95+
case mlir::structure_analysis::StructuredSparsityKind::LowerTriangular:
96+
ssKind = enzymexla::StructuredSparsityKind::LowerTriangular;
97+
break;
98+
case mlir::structure_analysis::StructuredSparsityKind::LowerBidiagonal:
99+
ssKind = enzymexla::StructuredSparsityKind::LowerBidiagonal;
100+
break;
101+
case mlir::structure_analysis::StructuredSparsityKind::Tridiagonal:
102+
ssKind = enzymexla::StructuredSparsityKind::Tridiagonal;
103+
break;
104+
case mlir::structure_analysis::StructuredSparsityKind::Diagonal:
105+
ssKind = enzymexla::StructuredSparsityKind::Diagonal;
106+
break;
107+
case mlir::structure_analysis::StructuredSparsityKind::Empty:
108+
ssKind = enzymexla::StructuredSparsityKind::Empty;
109+
break;
110+
}
74111

75-
// TODO: get structured sparsity kind
76112
auto structuredSparsityKind =
77113
enzymexla::StructuredSparsityPatternAttr::get(
78-
mod.getContext(), enzymexla::StructuredSparsityKind::Unknown,
114+
mod.getContext(), ssKind,
79115
state->getValue().getSparsityPattern().getLowerBandwidth(),
80116
state->getValue().getSparsityPattern().getUpperBandwidth());
81117

82118
SmallVector<enzymexla::StructuredValueProperty>
83119
structuredValueProperties;
84120
auto valueProperties = state->getValue().getProperties();
85121
if (valueProperties.hasUnitDiagonal()) {
122+
anyKnown = true;
86123
structuredValueProperties.push_back(
87124
enzymexla::StructuredValueProperty::UnitDiagonal);
88125
}
89126
if (valueProperties.isSymmetric()) {
127+
anyKnown = true;
90128
structuredValueProperties.push_back(
91129
enzymexla::StructuredValueProperty::Symmetric);
92130
}
93131
if (valueProperties.isHermitian()) {
132+
anyKnown = true;
94133
structuredValueProperties.push_back(
95134
enzymexla::StructuredValueProperty::Hermitian);
96135
}
97136
if (valueProperties.isBroadcastedScalar()) {
137+
anyKnown = true;
98138
structuredValueProperties.push_back(
99139
enzymexla::StructuredValueProperty::BroadcastedScalar);
100140
}

0 commit comments

Comments
 (0)