@@ -156,9 +156,9 @@ StructuredSparsityPattern::join(const StructuredSparsityPattern &lhs,
156156}
157157
158158StructuredSparsityPattern StructuredSparsityPattern::propagateTranspose (
159- const StructuredSparsityPattern &op) {
160- auto newPattern = StructuredSparsityPattern (op. upperBandwidth ,
161- op.lowerBandwidth );
159+ Value val, const StructuredSparsityPattern &op) {
160+ auto newPattern =
161+ StructuredSparsityPattern (op. upperBandwidth , op.lowerBandwidth );
162162 newPattern.refineKind ();
163163 return newPattern;
164164}
@@ -253,11 +253,52 @@ ValueProperties::ValueProperties(Value v) {
253253 }
254254 }
255255
256- // TODO: A x A^T will always be symmetric
256+ if (auto dotGeneralOp = dyn_cast<stablehlo::DotGeneralOp>(defOp)) {
257+ auto dotDimNumbers = dotGeneralOp.getDotDimensionNumbers ();
258+ auto lhs = dotGeneralOp.getLhs ();
259+ auto rhs = dotGeneralOp.getRhs ();
260+
261+ if (dotDimNumbers.getLhsBatchingDimensions ().size () == 0 &&
262+ dotDimNumbers.getRhsBatchingDimensions ().size () == 0 ) {
263+ // lhs == rhs => check for the dimension numbers
264+ if (lhs == rhs) {
265+ if (dotDimNumbers.getLhsContractingDimensions ().size () == 1 &&
266+ dotDimNumbers.getRhsContractingDimensions ().size () == 1 &&
267+ dotDimNumbers.getLhsContractingDimensions ()[0 ] ==
268+ dotDimNumbers.getRhsContractingDimensions ()[0 ]) {
269+ set (ValueProperty::Symmetric);
270+ }
271+ }
272+
273+ // check operands are transposed: `A x A^T` and `A^T x A`
274+ if (auto lhsT = lhs.getDefiningOp <stablehlo::TransposeOp>()) {
275+ if (isTrueTranspose (lhsT) && rhs == lhsT.getOperand ()) {
276+ if (dotDimNumbers.getLhsContractingDimensions ().size () == 1 &&
277+ dotDimNumbers.getRhsContractingDimensions ().size () == 1 &&
278+ dotDimNumbers.getLhsContractingDimensions ()[0 ] ==
279+ 1 - dotDimNumbers.getRhsContractingDimensions ()[0 ]) {
280+ set (ValueProperty::Symmetric);
281+ }
282+ }
283+ }
284+
285+ if (auto rhsT = rhs.getDefiningOp <stablehlo::TransposeOp>()) {
286+ if (isTrueTranspose (rhsT) && lhs == rhsT.getOperand ()) {
287+ if (dotDimNumbers.getLhsContractingDimensions ().size () == 1 &&
288+ dotDimNumbers.getRhsContractingDimensions ().size () == 1 &&
289+ dotDimNumbers.getLhsContractingDimensions ()[0 ] ==
290+ 1 - dotDimNumbers.getRhsContractingDimensions ()[0 ]) {
291+ set (ValueProperty::Symmetric);
292+ }
293+ }
294+ }
295+ }
296+ }
257297
258298 if (auto bcastOp = dyn_cast<stablehlo::BroadcastInDimOp>(defOp)) {
259299 auto operand = bcastOp.getOperand ();
260- if (cast<RankedTensorType>(operand.getType ()).getRank () == 0 ) { // bcast(scalar)
300+ if (cast<RankedTensorType>(operand.getType ()).getRank () ==
301+ 0 ) { // bcast(scalar)
261302 if (matchPattern (operand, m_One ())) // bcast(1)
262303 set (ValueProperty::UnitDiagonal);
263304 set (ValueProperty::BroadcastedScalar);
@@ -418,31 +459,86 @@ void StructuredMatrixType::print(raw_ostream &os) const {
418459 os << " )" ;
419460}
420461
421- StructuredMatrixType StructuredMatrixType::propagateTranspose (
422- const StructuredMatrixType &op) {
462+ StructuredMatrixType
463+ StructuredMatrixType::propagateTranspose (Value val,
464+ const StructuredMatrixType &op) {
423465 return StructuredMatrixType (
424- StructuredSparsityPattern::propagateTranspose (op.sparsityPattern ),
466+ StructuredSparsityPattern::propagateTranspose (val, op.sparsityPattern ),
425467 op.valueProperties );
426468}
427469
428- StructuredMatrixType StructuredMatrixType::propagateAdd (
429- const StructuredMatrixType &lhs, const StructuredMatrixType &rhs) {
470+ StructuredMatrixType
471+ StructuredMatrixType::propagateAdd (Value lhs, Value rhs,
472+ const StructuredMatrixType &lhsType,
473+ const StructuredMatrixType &rhsType) {
430474 ValueProperties valProps;
431- // TODO: If one is unit diag and other is zeros, we can propagate the other
475+
476+ // If one is unit diag and other is zeros, we can propagate the other
432477 // to the unit diag
433- if (lhs.getProperties ().isSymmetric () && rhs.getProperties ().isSymmetric ()) {
478+ SplatElementsAttr lhsSplat, rhsSplat;
479+ if (lhsType.getProperties ().hasUnitDiagonal () &&
480+ matchPattern (lhs, m_Constant (&lhsSplat))) {
481+ if (utils::isZero (lhsSplat.getSplatValue <Attribute>())) {
482+ valProps.set (ValueProperty::UnitDiagonal);
483+ }
484+ }
485+ if (rhsType.getProperties ().hasUnitDiagonal () &&
486+ matchPattern (rhs, m_Constant (&rhsSplat))) {
487+ if (utils::isZero (rhsSplat.getSplatValue <Attribute>())) {
488+ valProps.set (ValueProperty::UnitDiagonal);
489+ }
490+ }
491+
492+ if (lhsType.getProperties ().isSymmetric () &&
493+ rhsType.getProperties ().isSymmetric ()) {
434494 valProps.set (ValueProperty::Symmetric);
435495 }
436- if (lhs .getProperties ().isBroadcastedScalar () &&
437- rhs .getProperties ().isBroadcastedScalar ()) {
496+ if (lhsType .getProperties ().isBroadcastedScalar () &&
497+ rhsType .getProperties ().isBroadcastedScalar ()) {
438498 valProps.set (ValueProperty::BroadcastedScalar);
439499 }
440500
441501 return StructuredMatrixType (
442- StructuredSparsityPattern::meet (lhs.sparsityPattern , rhs.sparsityPattern ),
502+ StructuredSparsityPattern::meet (lhsType.sparsityPattern ,
503+ rhsType.sparsityPattern ),
443504 valProps);
444505}
445506
507+ StructuredMatrixType
508+ StructuredMatrixType::propagateMultiply (Value lhs, Value rhs,
509+ const StructuredMatrixType &lhsType,
510+ const StructuredMatrixType &rhsType) {
511+ return StructuredMatrixType::meet (lhsType, rhsType);
512+ }
513+
514+ // TODO: we ideally want to special case elementwise ops that preserve certain
515+ // properties
516+ StructuredMatrixType StructuredMatrixType::propagateElementwise (
517+ ArrayRef<Value> operands,
518+ SmallVectorImpl<StructuredMatrixType> &operandsType) {
519+ // TODO: propagate structure
520+
521+ ValueProperties valueProperties;
522+ // TODO: propagate hermitian
523+ bool allSymmetric = true , allScalar = true ;
524+ for (auto opType : operandsType) {
525+ if (!opType.getProperties ().isSymmetric ()) {
526+ allSymmetric = false ;
527+ }
528+ if (!opType.getProperties ().isBroadcastedScalar ()) {
529+ allScalar = false ;
530+ }
531+ }
532+ if (allSymmetric) {
533+ valueProperties.set (ValueProperty::Symmetric);
534+ }
535+ if (allScalar) {
536+ valueProperties.set (ValueProperty::BroadcastedScalar);
537+ }
538+
539+ return StructuredMatrixType (StructuredSparsityPattern (), valueProperties);
540+ }
541+
446542// ===----------------------------------------------------------------------===//
447543// Lattice Element
448544// ===----------------------------------------------------------------------===//
@@ -498,45 +594,59 @@ LogicalResult StructuredMatrixAnalysis::visitOperation(
498594 SmallVector<bool > updatedProps (results.size (), false );
499595 SmallVector<StructuredMatrixType> propagatedProps (results.size ());
500596
597+ SmallVector<StructuredMatrixType> operandValues (operands.size ());
598+ for (size_t i = 0 ; i < operands.size (); i++) {
599+ operandValues[i] = operands[i]->getValue ();
600+ }
601+
501602 // transpose
502603 if (auto transposeOp = dyn_cast<stablehlo::TransposeOp>(op)) {
503604 updatedProps[0 ] = true ;
504605 propagatedProps[0 ] = StructuredMatrixType::propagateTranspose (
505- operands [0 ]-> getValue () );
606+ transposeOp. getOperand (), operandValues [0 ]);
506607 }
507608
508609 // elementwise
509610 // / add
510611 if (auto addOp = dyn_cast<stablehlo::AddOp>(op)) {
511612 updatedProps[0 ] = true ;
512613 propagatedProps[0 ] = StructuredMatrixType::propagateAdd (
513- operands [0 ]-> getValue (), operands [1 ]-> getValue () );
614+ addOp. getLhs (), addOp. getRhs (), operandValues [0 ], operandValues [1 ]);
514615 }
515616
516617 // / mul
618+ if (auto mulOp = dyn_cast<stablehlo::MulOp>(op)) {
619+ updatedProps[0 ] = true ;
620+ propagatedProps[0 ] = StructuredMatrixType::propagateMultiply (
621+ mulOp.getLhs (), mulOp.getRhs (), operandValues[0 ], operandValues[1 ]);
622+ }
623+
624+ // / fallback for other elementwise ops
625+ if (stablehlo::hasTraitElementwise (op)) {
626+ updatedProps[0 ] = true ;
627+ propagatedProps[0 ] = StructuredMatrixType::propagateElementwise (
628+ llvm::to_vector<3 >(op->getOperands ()), operandValues);
629+ }
630+
631+ // pass through ops
632+ if (isa<stablehlo::ConvertOp>(op)) {
633+ updatedProps[0 ] = true ;
634+ propagatedProps[0 ] = operandValues[0 ];
635+ }
517636
518637 // finalize
519638 for (size_t i = 0 ; i < results.size (); i++) {
520639 if (updatedProps[i]) {
521- results[i]->setValue (
522- StructuredMatrixType::join (results[i]->getValue (), propagatedProps[i]));
640+ auto resultOrig = results[i]->getValue ();
641+ auto resultNew =
642+ StructuredMatrixType::join (resultOrig, propagatedProps[i]);
643+ results[i]->setValue (resultNew);
644+ propagateIfChanged (results[i], resultNew == resultOrig
645+ ? ChangeResult::NoChange
646+ : ChangeResult::Change);
523647 }
524648 }
525649
526-
527- llvm::errs () << " Visiting operation " << *op << " \n " ;
528- for (auto operand : operands) {
529- llvm::errs () << " operand: " ;
530- operand->getValue ().print (llvm::errs ());
531- llvm::errs () << " \n " ;
532- }
533- for (auto result : results) {
534- llvm::errs () << " result: " ;
535- result->getValue ().print (llvm::errs ());
536- llvm::errs () << " \n " ;
537- }
538- llvm::errs () << " \n " ;
539-
540650 return success ();
541651}
542652
0 commit comments