Skip to content

Commit 04419d2

Browse files
committed
feat: more propagation rules
1 parent 3da6a56 commit 04419d2

File tree

2 files changed

+82
-8
lines changed

2 files changed

+82
-8
lines changed

src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp

Lines changed: 70 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@ StructuredSparsityPattern::StructuredSparsityPattern(Value v) {
3333
return;
3434
}
3535

36-
llvm::errs() << "TODO: structured sparsity pattern not implemented for " << v
37-
<< "\n";
3836
setUnknown();
3937
return;
4038
}
@@ -130,9 +128,8 @@ StructuredSparsityPattern::meet(const StructuredSparsityPattern &lhs,
130128
if (rhs.kind == StructuredSparsityKind::Unknown)
131129
return lhs;
132130

133-
// for all other cases, we take the min of the bandwidths and refine
134-
auto lb = std::min(lhs.lowerBandwidth, rhs.lowerBandwidth);
135-
auto ub = std::min(lhs.upperBandwidth, rhs.upperBandwidth);
131+
auto lb = std::max(lhs.lowerBandwidth, rhs.lowerBandwidth);
132+
auto ub = std::max(lhs.upperBandwidth, rhs.upperBandwidth);
136133
auto newPattern = StructuredSparsityPattern(lb, ub);
137134
newPattern.refineKind();
138135
return newPattern;
@@ -151,13 +148,21 @@ StructuredSparsityPattern::join(const StructuredSparsityPattern &lhs,
151148
rhs.kind == StructuredSparsityKind::Unknown)
152149
return StructuredSparsityPattern(StructuredSparsityKind::Unknown);
153150

154-
auto lb = std::max(lhs.lowerBandwidth, rhs.lowerBandwidth);
155-
auto ub = std::max(lhs.upperBandwidth, rhs.upperBandwidth);
151+
auto lb = std::min(lhs.lowerBandwidth, rhs.lowerBandwidth);
152+
auto ub = std::min(lhs.upperBandwidth, rhs.upperBandwidth);
156153
auto newPattern = StructuredSparsityPattern(lb, ub);
157154
newPattern.refineKind();
158155
return newPattern;
159156
}
160157

158+
StructuredSparsityPattern StructuredSparsityPattern::propagateTranspose(
159+
const StructuredSparsityPattern &op) {
160+
auto newPattern = StructuredSparsityPattern(op.upperBandwidth,
161+
op.lowerBandwidth);
162+
newPattern.refineKind();
163+
return newPattern;
164+
}
165+
161166
void StructuredSparsityPattern::print(raw_ostream &os) const {
162167
switch (kind) {
163168
case StructuredSparsityKind::Unknown:
@@ -413,6 +418,31 @@ void StructuredMatrixType::print(raw_ostream &os) const {
413418
os << ")";
414419
}
415420

421+
StructuredMatrixType StructuredMatrixType::propagateTranspose(
422+
const StructuredMatrixType &op) {
423+
return StructuredMatrixType(
424+
StructuredSparsityPattern::propagateTranspose(op.sparsityPattern),
425+
op.valueProperties);
426+
}
427+
428+
StructuredMatrixType StructuredMatrixType::propagateAdd(
429+
const StructuredMatrixType &lhs, const StructuredMatrixType &rhs) {
430+
ValueProperties valProps;
431+
// TODO: If one is unit diag and other is zeros, we can propagate the other
432+
// to the unit diag
433+
if (lhs.getProperties().isSymmetric() && rhs.getProperties().isSymmetric()) {
434+
valProps.set(ValueProperty::Symmetric);
435+
}
436+
if (lhs.getProperties().isBroadcastedScalar() &&
437+
rhs.getProperties().isBroadcastedScalar()) {
438+
valProps.set(ValueProperty::BroadcastedScalar);
439+
}
440+
441+
return StructuredMatrixType(
442+
StructuredSparsityPattern::meet(lhs.sparsityPattern, rhs.sparsityPattern),
443+
valProps);
444+
}
445+
416446
//===----------------------------------------------------------------------===//
417447
// Lattice Element
418448
//===----------------------------------------------------------------------===//
@@ -465,13 +495,46 @@ void StructuredMatrixAnalysis::setToEntryState(
465495
LogicalResult StructuredMatrixAnalysis::visitOperation(
466496
Operation *op, ArrayRef<const StructuredMatrixLattice *> operands,
467497
ArrayRef<StructuredMatrixLattice *> results) {
498+
SmallVector<bool> updatedProps(results.size(), false);
499+
SmallVector<StructuredMatrixType> propagatedProps(results.size());
500+
501+
// transpose
502+
if (auto transposeOp = dyn_cast<stablehlo::TransposeOp>(op)) {
503+
updatedProps[0] = true;
504+
propagatedProps[0] = StructuredMatrixType::propagateTranspose(
505+
operands[0]->getValue());
506+
}
507+
508+
// elementwise
509+
/// add
510+
if (auto addOp = dyn_cast<stablehlo::AddOp>(op)) {
511+
updatedProps[0] = true;
512+
propagatedProps[0] = StructuredMatrixType::propagateAdd(
513+
operands[0]->getValue(), operands[1]->getValue());
514+
}
515+
516+
/// mul
517+
518+
// finalize
519+
for (size_t i = 0; i < results.size(); i++) {
520+
if (updatedProps[i]) {
521+
results[i]->setValue(
522+
StructuredMatrixType::join(results[i]->getValue(), propagatedProps[i]));
523+
}
524+
}
525+
468526

469527
llvm::errs() << "Visiting operation " << *op << "\n";
470528
for (auto operand : operands) {
471529
llvm::errs() << " operand: ";
472530
operand->getValue().print(llvm::errs());
473531
llvm::errs() << "\n";
474532
}
533+
for (auto result : results) {
534+
llvm::errs() << " result: ";
535+
result->getValue().print(llvm::errs());
536+
llvm::errs() << "\n";
537+
}
475538
llvm::errs() << "\n";
476539

477540
return success();

src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,10 @@ class StructuredSparsityPattern {
8888
return os;
8989
}
9090

91+
// propagation rules
92+
static StructuredSparsityPattern propagateTranspose(
93+
const StructuredSparsityPattern &op);
94+
9195
private:
9296
void initializeBandwidths();
9397
void refineKind();
@@ -203,7 +207,14 @@ class StructuredMatrixType {
203207
return os;
204208
}
205209

206-
// TODO: propagation rules probably goes in here
210+
// propagation rules
211+
static StructuredMatrixType propagateTranspose(const StructuredMatrixType &op);
212+
213+
static StructuredMatrixType propagateAdd(const StructuredMatrixType &lhs,
214+
const StructuredMatrixType &rhs);
215+
216+
static StructuredMatrixType propagateMultiply(const StructuredMatrixType &lhs,
217+
const StructuredMatrixType &rhs);
207218

208219
// TODO: implement queries that check both the sparsity pattern and value
209220
// properties and return specific matrix kinds

0 commit comments

Comments
 (0)