@@ -33,8 +33,6 @@ StructuredSparsityPattern::StructuredSparsityPattern(Value v) {
3333 return ;
3434 }
3535
36- llvm::errs () << " TODO: structured sparsity pattern not implemented for " << v
37- << " \n " ;
3836 setUnknown ();
3937 return ;
4038}
@@ -130,9 +128,8 @@ StructuredSparsityPattern::meet(const StructuredSparsityPattern &lhs,
130128 if (rhs.kind == StructuredSparsityKind::Unknown)
131129 return lhs;
132130
133- // for all other cases, we take the min of the bandwidths and refine
134- auto lb = std::min (lhs.lowerBandwidth , rhs.lowerBandwidth );
135- auto ub = std::min (lhs.upperBandwidth , rhs.upperBandwidth );
131+ auto lb = std::max (lhs.lowerBandwidth , rhs.lowerBandwidth );
132+ auto ub = std::max (lhs.upperBandwidth , rhs.upperBandwidth );
136133 auto newPattern = StructuredSparsityPattern (lb, ub);
137134 newPattern.refineKind ();
138135 return newPattern;
@@ -151,13 +148,21 @@ StructuredSparsityPattern::join(const StructuredSparsityPattern &lhs,
151148 rhs.kind == StructuredSparsityKind::Unknown)
152149 return StructuredSparsityPattern (StructuredSparsityKind::Unknown);
153150
154- auto lb = std::max (lhs.lowerBandwidth , rhs.lowerBandwidth );
155- auto ub = std::max (lhs.upperBandwidth , rhs.upperBandwidth );
151+ auto lb = std::min (lhs.lowerBandwidth , rhs.lowerBandwidth );
152+ auto ub = std::min (lhs.upperBandwidth , rhs.upperBandwidth );
156153 auto newPattern = StructuredSparsityPattern (lb, ub);
157154 newPattern.refineKind ();
158155 return newPattern;
159156}
160157
158+ StructuredSparsityPattern StructuredSparsityPattern::propagateTranspose (
159+ const StructuredSparsityPattern &op) {
160+ auto newPattern = StructuredSparsityPattern (op.upperBandwidth ,
161+ op.lowerBandwidth );
162+ newPattern.refineKind ();
163+ return newPattern;
164+ }
165+
161166void StructuredSparsityPattern::print (raw_ostream &os) const {
162167 switch (kind) {
163168 case StructuredSparsityKind::Unknown:
@@ -413,6 +418,31 @@ void StructuredMatrixType::print(raw_ostream &os) const {
413418 os << " )" ;
414419}
415420
421+ StructuredMatrixType StructuredMatrixType::propagateTranspose (
422+ const StructuredMatrixType &op) {
423+ return StructuredMatrixType (
424+ StructuredSparsityPattern::propagateTranspose (op.sparsityPattern ),
425+ op.valueProperties );
426+ }
427+
428+ StructuredMatrixType StructuredMatrixType::propagateAdd (
429+ const StructuredMatrixType &lhs, const StructuredMatrixType &rhs) {
430+ ValueProperties valProps;
431+ // TODO: If one is unit diag and other is zeros, we can propagate the other
432+ // to the unit diag
433+ if (lhs.getProperties ().isSymmetric () && rhs.getProperties ().isSymmetric ()) {
434+ valProps.set (ValueProperty::Symmetric);
435+ }
436+ if (lhs.getProperties ().isBroadcastedScalar () &&
437+ rhs.getProperties ().isBroadcastedScalar ()) {
438+ valProps.set (ValueProperty::BroadcastedScalar);
439+ }
440+
441+ return StructuredMatrixType (
442+ StructuredSparsityPattern::meet (lhs.sparsityPattern , rhs.sparsityPattern ),
443+ valProps);
444+ }
445+
416446// ===----------------------------------------------------------------------===//
417447// Lattice Element
418448// ===----------------------------------------------------------------------===//
@@ -465,13 +495,46 @@ void StructuredMatrixAnalysis::setToEntryState(
465495LogicalResult StructuredMatrixAnalysis::visitOperation (
466496 Operation *op, ArrayRef<const StructuredMatrixLattice *> operands,
467497 ArrayRef<StructuredMatrixLattice *> results) {
498+ SmallVector<bool > updatedProps (results.size (), false );
499+ SmallVector<StructuredMatrixType> propagatedProps (results.size ());
500+
501+ // transpose
502+ if (auto transposeOp = dyn_cast<stablehlo::TransposeOp>(op)) {
503+ updatedProps[0 ] = true ;
504+ propagatedProps[0 ] = StructuredMatrixType::propagateTranspose (
505+ operands[0 ]->getValue ());
506+ }
507+
508+ // elementwise
509+ // / add
510+ if (auto addOp = dyn_cast<stablehlo::AddOp>(op)) {
511+ updatedProps[0 ] = true ;
512+ propagatedProps[0 ] = StructuredMatrixType::propagateAdd (
513+ operands[0 ]->getValue (), operands[1 ]->getValue ());
514+ }
515+
516+ // / mul
517+
518+ // finalize
519+ for (size_t i = 0 ; i < results.size (); i++) {
520+ if (updatedProps[i]) {
521+ results[i]->setValue (
522+ StructuredMatrixType::join (results[i]->getValue (), propagatedProps[i]));
523+ }
524+ }
525+
468526
469527 llvm::errs () << " Visiting operation " << *op << " \n " ;
470528 for (auto operand : operands) {
471529 llvm::errs () << " operand: " ;
472530 operand->getValue ().print (llvm::errs ());
473531 llvm::errs () << " \n " ;
474532 }
533+ for (auto result : results) {
534+ llvm::errs () << " result: " ;
535+ result->getValue ().print (llvm::errs ());
536+ llvm::errs () << " \n " ;
537+ }
475538 llvm::errs () << " \n " ;
476539
477540 return success ();
0 commit comments