@@ -70,31 +70,71 @@ class StructuredMatrixSimplifyPass
7070 continue ;
7171 }
7272
73- anyKnown = true ;
73+ if (state->getValue ().getSparsityPattern ().getKind () !=
74+ mlir::structure_analysis::StructuredSparsityKind::Unknown) {
75+ anyKnown = true ;
76+ }
77+
78+ enzymexla::StructuredSparsityKind ssKind;
79+ switch (state->getValue ().getSparsityPattern ().getKind ()) {
80+ case mlir::structure_analysis::StructuredSparsityKind::Unknown:
81+ ssKind = enzymexla::StructuredSparsityKind::Unknown;
82+ break ;
83+ case mlir::structure_analysis::StructuredSparsityKind::Dense:
84+ ssKind = enzymexla::StructuredSparsityKind::Dense;
85+ break ;
86+ case mlir::structure_analysis::StructuredSparsityKind::Band:
87+ ssKind = enzymexla::StructuredSparsityKind::Band;
88+ break ;
89+ case mlir::structure_analysis::StructuredSparsityKind::UpperTriangular:
90+ ssKind = enzymexla::StructuredSparsityKind::UpperTriangular;
91+ break ;
92+ case mlir::structure_analysis::StructuredSparsityKind::UpperBidiagonal:
93+ ssKind = enzymexla::StructuredSparsityKind::UpperBidiagonal;
94+ break ;
95+ case mlir::structure_analysis::StructuredSparsityKind::LowerTriangular:
96+ ssKind = enzymexla::StructuredSparsityKind::LowerTriangular;
97+ break ;
98+ case mlir::structure_analysis::StructuredSparsityKind::LowerBidiagonal:
99+ ssKind = enzymexla::StructuredSparsityKind::LowerBidiagonal;
100+ break ;
101+ case mlir::structure_analysis::StructuredSparsityKind::Tridiagonal:
102+ ssKind = enzymexla::StructuredSparsityKind::Tridiagonal;
103+ break ;
104+ case mlir::structure_analysis::StructuredSparsityKind::Diagonal:
105+ ssKind = enzymexla::StructuredSparsityKind::Diagonal;
106+ break ;
107+ case mlir::structure_analysis::StructuredSparsityKind::Empty:
108+ ssKind = enzymexla::StructuredSparsityKind::Empty;
109+ break ;
110+ }
74111
75- // TODO: get structured sparsity kind
76112 auto structuredSparsityKind =
77113 enzymexla::StructuredSparsityPatternAttr::get (
78- mod.getContext (), enzymexla::StructuredSparsityKind::Unknown ,
114+ mod.getContext (), ssKind ,
79115 state->getValue ().getSparsityPattern ().getLowerBandwidth (),
80116 state->getValue ().getSparsityPattern ().getUpperBandwidth ());
81117
82118 SmallVector<enzymexla::StructuredValueProperty>
83119 structuredValueProperties;
84120 auto valueProperties = state->getValue ().getProperties ();
85121 if (valueProperties.hasUnitDiagonal ()) {
122+ anyKnown = true ;
86123 structuredValueProperties.push_back (
87124 enzymexla::StructuredValueProperty::UnitDiagonal);
88125 }
89126 if (valueProperties.isSymmetric ()) {
127+ anyKnown = true ;
90128 structuredValueProperties.push_back (
91129 enzymexla::StructuredValueProperty::Symmetric);
92130 }
93131 if (valueProperties.isHermitian ()) {
132+ anyKnown = true ;
94133 structuredValueProperties.push_back (
95134 enzymexla::StructuredValueProperty::Hermitian);
96135 }
97136 if (valueProperties.isBroadcastedScalar ()) {
137+ anyKnown = true ;
98138 structuredValueProperties.push_back (
99139 enzymexla::StructuredValueProperty::BroadcastedScalar);
100140 }
0 commit comments