Skip to content

Commit a519243

Browse files
committed
feat: more propagation rules
1 parent 04419d2 commit a519243

File tree

3 files changed

+168
-41
lines changed

3 files changed

+168
-41
lines changed

src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp

Lines changed: 143 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -156,9 +156,9 @@ StructuredSparsityPattern::join(const StructuredSparsityPattern &lhs,
156156
}
157157

158158
StructuredSparsityPattern StructuredSparsityPattern::propagateTranspose(
159-
const StructuredSparsityPattern &op) {
160-
auto newPattern = StructuredSparsityPattern(op.upperBandwidth,
161-
op.lowerBandwidth);
159+
Value val, const StructuredSparsityPattern &op) {
160+
auto newPattern =
161+
StructuredSparsityPattern(op.upperBandwidth, op.lowerBandwidth);
162162
newPattern.refineKind();
163163
return newPattern;
164164
}
@@ -253,11 +253,52 @@ ValueProperties::ValueProperties(Value v) {
253253
}
254254
}
255255

256-
// TODO: A x A^T will always be symmetric
256+
if (auto dotGeneralOp = dyn_cast<stablehlo::DotGeneralOp>(defOp)) {
257+
auto dotDimNumbers = dotGeneralOp.getDotDimensionNumbers();
258+
auto lhs = dotGeneralOp.getLhs();
259+
auto rhs = dotGeneralOp.getRhs();
260+
261+
if (dotDimNumbers.getLhsBatchingDimensions().size() == 0 &&
262+
dotDimNumbers.getRhsBatchingDimensions().size() == 0) {
263+
// lhs == rhs => check for the dimension numbers
264+
if (lhs == rhs) {
265+
if (dotDimNumbers.getLhsContractingDimensions().size() == 1 &&
266+
dotDimNumbers.getRhsContractingDimensions().size() == 1 &&
267+
dotDimNumbers.getLhsContractingDimensions()[0] ==
268+
dotDimNumbers.getRhsContractingDimensions()[0]) {
269+
set(ValueProperty::Symmetric);
270+
}
271+
}
272+
273+
// check operands are transposed: `A x A^T` and `A^T x A`
274+
if (auto lhsT = lhs.getDefiningOp<stablehlo::TransposeOp>()) {
275+
if (isTrueTranspose(lhsT) && rhs == lhsT.getOperand()) {
276+
if (dotDimNumbers.getLhsContractingDimensions().size() == 1 &&
277+
dotDimNumbers.getRhsContractingDimensions().size() == 1 &&
278+
dotDimNumbers.getLhsContractingDimensions()[0] ==
279+
1 - dotDimNumbers.getRhsContractingDimensions()[0]) {
280+
set(ValueProperty::Symmetric);
281+
}
282+
}
283+
}
284+
285+
if (auto rhsT = rhs.getDefiningOp<stablehlo::TransposeOp>()) {
286+
if (isTrueTranspose(rhsT) && lhs == rhsT.getOperand()) {
287+
if (dotDimNumbers.getLhsContractingDimensions().size() == 1 &&
288+
dotDimNumbers.getRhsContractingDimensions().size() == 1 &&
289+
dotDimNumbers.getLhsContractingDimensions()[0] ==
290+
1 - dotDimNumbers.getRhsContractingDimensions()[0]) {
291+
set(ValueProperty::Symmetric);
292+
}
293+
}
294+
}
295+
}
296+
}
257297

258298
if (auto bcastOp = dyn_cast<stablehlo::BroadcastInDimOp>(defOp)) {
259299
auto operand = bcastOp.getOperand();
260-
if (cast<RankedTensorType>(operand.getType()).getRank() == 0) { // bcast(scalar)
300+
if (cast<RankedTensorType>(operand.getType()).getRank() ==
301+
0) { // bcast(scalar)
261302
if (matchPattern(operand, m_One())) // bcast(1)
262303
set(ValueProperty::UnitDiagonal);
263304
set(ValueProperty::BroadcastedScalar);
@@ -418,31 +459,86 @@ void StructuredMatrixType::print(raw_ostream &os) const {
418459
os << ")";
419460
}
420461

421-
StructuredMatrixType StructuredMatrixType::propagateTranspose(
422-
const StructuredMatrixType &op) {
462+
StructuredMatrixType
463+
StructuredMatrixType::propagateTranspose(Value val,
464+
const StructuredMatrixType &op) {
423465
return StructuredMatrixType(
424-
StructuredSparsityPattern::propagateTranspose(op.sparsityPattern),
466+
StructuredSparsityPattern::propagateTranspose(val, op.sparsityPattern),
425467
op.valueProperties);
426468
}
427469

428-
StructuredMatrixType StructuredMatrixType::propagateAdd(
429-
const StructuredMatrixType &lhs, const StructuredMatrixType &rhs) {
470+
StructuredMatrixType
471+
StructuredMatrixType::propagateAdd(Value lhs, Value rhs,
472+
const StructuredMatrixType &lhsType,
473+
const StructuredMatrixType &rhsType) {
430474
ValueProperties valProps;
431-
// TODO: If one is unit diag and other is zeros, we can propagate the other
475+
476+
// If one is unit diag and other is zeros, we can propagate the other
432477
// to the unit diag
433-
if (lhs.getProperties().isSymmetric() && rhs.getProperties().isSymmetric()) {
478+
SplatElementsAttr lhsSplat, rhsSplat;
479+
if (lhsType.getProperties().hasUnitDiagonal() &&
480+
matchPattern(lhs, m_Constant(&lhsSplat))) {
481+
if (utils::isZero(lhsSplat.getSplatValue<Attribute>())) {
482+
valProps.set(ValueProperty::UnitDiagonal);
483+
}
484+
}
485+
if (rhsType.getProperties().hasUnitDiagonal() &&
486+
matchPattern(rhs, m_Constant(&rhsSplat))) {
487+
if (utils::isZero(rhsSplat.getSplatValue<Attribute>())) {
488+
valProps.set(ValueProperty::UnitDiagonal);
489+
}
490+
}
491+
492+
if (lhsType.getProperties().isSymmetric() &&
493+
rhsType.getProperties().isSymmetric()) {
434494
valProps.set(ValueProperty::Symmetric);
435495
}
436-
if (lhs.getProperties().isBroadcastedScalar() &&
437-
rhs.getProperties().isBroadcastedScalar()) {
496+
if (lhsType.getProperties().isBroadcastedScalar() &&
497+
rhsType.getProperties().isBroadcastedScalar()) {
438498
valProps.set(ValueProperty::BroadcastedScalar);
439499
}
440500

441501
return StructuredMatrixType(
442-
StructuredSparsityPattern::meet(lhs.sparsityPattern, rhs.sparsityPattern),
502+
StructuredSparsityPattern::meet(lhsType.sparsityPattern,
503+
rhsType.sparsityPattern),
443504
valProps);
444505
}
445506

507+
StructuredMatrixType
508+
StructuredMatrixType::propagateMultiply(Value lhs, Value rhs,
509+
const StructuredMatrixType &lhsType,
510+
const StructuredMatrixType &rhsType) {
511+
return StructuredMatrixType::meet(lhsType, rhsType);
512+
}
513+
514+
// TODO: we ideally want to special case elementwise ops that preserve certain
515+
// properties
516+
StructuredMatrixType StructuredMatrixType::propagateElementwise(
517+
ArrayRef<Value> operands,
518+
SmallVectorImpl<StructuredMatrixType> &operandsType) {
519+
// TODO: propagate structure
520+
521+
ValueProperties valueProperties;
522+
// TODO: propagate hermitian
523+
bool allSymmetric = true, allScalar = true;
524+
for (auto opType : operandsType) {
525+
if (!opType.getProperties().isSymmetric()) {
526+
allSymmetric = false;
527+
}
528+
if (!opType.getProperties().isBroadcastedScalar()) {
529+
allScalar = false;
530+
}
531+
}
532+
if (allSymmetric) {
533+
valueProperties.set(ValueProperty::Symmetric);
534+
}
535+
if (allScalar) {
536+
valueProperties.set(ValueProperty::BroadcastedScalar);
537+
}
538+
539+
return StructuredMatrixType(StructuredSparsityPattern(), valueProperties);
540+
}
541+
446542
//===----------------------------------------------------------------------===//
447543
// Lattice Element
448544
//===----------------------------------------------------------------------===//
@@ -498,45 +594,59 @@ LogicalResult StructuredMatrixAnalysis::visitOperation(
498594
SmallVector<bool> updatedProps(results.size(), false);
499595
SmallVector<StructuredMatrixType> propagatedProps(results.size());
500596

597+
SmallVector<StructuredMatrixType> operandValues(operands.size());
598+
for (size_t i = 0; i < operands.size(); i++) {
599+
operandValues[i] = operands[i]->getValue();
600+
}
601+
501602
// transpose
502603
if (auto transposeOp = dyn_cast<stablehlo::TransposeOp>(op)) {
503604
updatedProps[0] = true;
504605
propagatedProps[0] = StructuredMatrixType::propagateTranspose(
505-
operands[0]->getValue());
606+
transposeOp.getOperand(), operandValues[0]);
506607
}
507608

508609
// elementwise
509610
/// add
510611
if (auto addOp = dyn_cast<stablehlo::AddOp>(op)) {
511612
updatedProps[0] = true;
512613
propagatedProps[0] = StructuredMatrixType::propagateAdd(
513-
operands[0]->getValue(), operands[1]->getValue());
614+
addOp.getLhs(), addOp.getRhs(), operandValues[0], operandValues[1]);
514615
}
515616

516617
/// mul
618+
if (auto mulOp = dyn_cast<stablehlo::MulOp>(op)) {
619+
updatedProps[0] = true;
620+
propagatedProps[0] = StructuredMatrixType::propagateMultiply(
621+
mulOp.getLhs(), mulOp.getRhs(), operandValues[0], operandValues[1]);
622+
}
623+
624+
/// fallback for other elementwise ops
625+
if (stablehlo::hasTraitElementwise(op)) {
626+
updatedProps[0] = true;
627+
propagatedProps[0] = StructuredMatrixType::propagateElementwise(
628+
llvm::to_vector<3>(op->getOperands()), operandValues);
629+
}
630+
631+
// pass through ops
632+
if (isa<stablehlo::ConvertOp>(op)) {
633+
updatedProps[0] = true;
634+
propagatedProps[0] = operandValues[0];
635+
}
517636

518637
// finalize
519638
for (size_t i = 0; i < results.size(); i++) {
520639
if (updatedProps[i]) {
521-
results[i]->setValue(
522-
StructuredMatrixType::join(results[i]->getValue(), propagatedProps[i]));
640+
auto resultOrig = results[i]->getValue();
641+
auto resultNew =
642+
StructuredMatrixType::join(resultOrig, propagatedProps[i]);
643+
results[i]->setValue(resultNew);
644+
propagateIfChanged(results[i], resultNew == resultOrig
645+
? ChangeResult::NoChange
646+
: ChangeResult::Change);
523647
}
524648
}
525649

526-
527-
llvm::errs() << "Visiting operation " << *op << "\n";
528-
for (auto operand : operands) {
529-
llvm::errs() << " operand: ";
530-
operand->getValue().print(llvm::errs());
531-
llvm::errs() << "\n";
532-
}
533-
for (auto result : results) {
534-
llvm::errs() << " result: ";
535-
result->getValue().print(llvm::errs());
536-
llvm::errs() << "\n";
537-
}
538-
llvm::errs() << "\n";
539-
540650
return success();
541651
}
542652

src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,16 @@ namespace structure_analysis {
1212

1313
namespace utils {
1414

15+
static bool isZero(APInt v) { return v.isZero(); }
16+
static bool isZero(APFloat v) { return v.isZero(); }
17+
static bool isZero(Attribute v) {
18+
if (auto intAttr = dyn_cast<IntegerAttr>(v))
19+
return isZero(intAttr.getValue());
20+
if (auto floatAttr = dyn_cast<FloatAttr>(v))
21+
return isZero(floatAttr.getValue());
22+
return false;
23+
}
24+
1525
static bool isOne(APInt v) { return v.isOne(); }
1626
static bool isOne(APFloat v) { return v.isExactlyValue(1.0); }
1727
static bool isOne(Attribute v) {
@@ -89,8 +99,8 @@ class StructuredSparsityPattern {
8999
}
90100

91101
// propagation rules
92-
static StructuredSparsityPattern propagateTranspose(
93-
const StructuredSparsityPattern &op);
102+
static StructuredSparsityPattern
103+
propagateTranspose(Value val, const StructuredSparsityPattern &op);
94104

95105
private:
96106
void initializeBandwidths();
@@ -208,13 +218,20 @@ class StructuredMatrixType {
208218
}
209219

210220
// propagation rules
211-
static StructuredMatrixType propagateTranspose(const StructuredMatrixType &op);
221+
static StructuredMatrixType
222+
propagateTranspose(Value val, const StructuredMatrixType &op);
223+
224+
static StructuredMatrixType propagateAdd(Value lhs, Value rhs,
225+
const StructuredMatrixType &lhsType,
226+
const StructuredMatrixType &rhsType);
212227

213-
static StructuredMatrixType propagateAdd(const StructuredMatrixType &lhs,
214-
const StructuredMatrixType &rhs);
228+
static StructuredMatrixType
229+
propagateMultiply(Value lhs, Value rhs, const StructuredMatrixType &lhsType,
230+
const StructuredMatrixType &rhsType);
215231

216-
static StructuredMatrixType propagateMultiply(const StructuredMatrixType &lhs,
217-
const StructuredMatrixType &rhs);
232+
static StructuredMatrixType
233+
propagateElementwise(ArrayRef<Value> operands,
234+
SmallVectorImpl<StructuredMatrixType> &operandsType);
218235

219236
// TODO: implement queries that check both the sparsity pattern and value
220237
// properties and return specific matrix kinds

src/enzyme_ad/jax/Passes/StructuredMatrixSimplify.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ class StructuredMatrixSimplifyPass
147147
}
148148

149149
if (anyKnown) {
150-
op->setAttr("structured_sparsity",
150+
op->setAttr("enzymexla.structured_sparsity",
151151
ArrayAttr::get(mod.getContext(), structuredSparsityAttrs));
152152
}
153153

0 commit comments

Comments
 (0)