From d6f11e5e70bf8586483ea74b1db05ff65db79c99 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 18 Nov 2025 10:07:58 -0600 Subject: [PATCH 01/11] feat: general structured matrix lattice --- .../jax/Analysis/StructuredMatrixAnalysis.cpp | 56 ++++ .../jax/Analysis/StructuredMatrixAnalysis.h | 249 ++++++++++++++++++ src/enzyme_ad/jax/BUILD | 22 ++ 3 files changed, 327 insertions(+) create mode 100644 src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp create mode 100644 src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h diff --git a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp new file mode 100644 index 000000000..948b5c331 --- /dev/null +++ b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp @@ -0,0 +1,56 @@ +#include "src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h" + +#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Pass/Pass.h" + +#include "mlir/Support/LLVM.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; +using namespace mlir::dataflow; + +namespace mlir { +namespace structure_analysis { + +//===----------------------------------------------------------------------===// +// Structured Sparsity Pattern Implementation +//===----------------------------------------------------------------------===// + +void StructuredSparsityPattern::initializeBandwidths() { + switch (kind) { + case StructuredSparsityKind::Diagonal: + lowerBandwidth = 0; + upperBandwidth = 0; + break; + case StructuredSparsityKind::Bidiagonal: + lowerBandwidth = 0; + upperBandwidth = 1; + break; + case StructuredSparsityKind::Tridiagonal: + lowerBandwidth = 1; + upperBandwidth = 1; + break; + case StructuredSparsityKind::UpperTriangular: + lowerBandwidth = 0; + 
upperBandwidth = std::numeric_limits::max(); + break; + case StructuredSparsityKind::LowerTriangular: + lowerBandwidth = std::numeric_limits::max(); + upperBandwidth = 0; + break; + default: + break; + } +} + +//===----------------------------------------------------------------------===// +// Value Properties Implementation +//===----------------------------------------------------------------------===// + +} // namespace structure_analysis +} // namespace mlir diff --git a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h new file mode 100644 index 000000000..b630d55d4 --- /dev/null +++ b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h @@ -0,0 +1,249 @@ +#include +#include + +namespace mlir { +namespace structure_analysis { + +//===----------------------------------------------------------------------===// +// Structured Sparsity Pattern Implementation +//===----------------------------------------------------------------------===// + +enum class StructuredSparsityKind { + Unknown, + Dense, + Band, + UpperTriangular, + LowerTriangular, + Diagonal, + Bidiagonal, + Tridiagonal, + Empty, +}; + +// TODO: currently only legal negative value is -1, which means "unknown" +// we should support negative bandwidths +class StructuredSparsityPattern { +public: + StructuredSparsityPattern() + : kind(StructuredSparsityKind::Unknown), lowerBandwidth(0), + upperBandwidth(0) {} + + explicit StructuredSparsityPattern(StructuredSparsityKind kind) + : kind(kind), lowerBandwidth(-1), upperBandwidth(-1) { + initializeBandwidths(); + } + + StructuredSparsityPattern(int64_t lowerBandwidth, int64_t upperBandwidth) + : kind(StructuredSparsityKind::Band), lowerBandwidth(lowerBandwidth), + upperBandwidth(upperBandwidth) {} + + static StructuredSparsityPattern meet(const StructuredSparsityPattern &lhs, + const StructuredSparsityPattern &rhs) { + if (lhs.kind == StructuredSparsityKind::Empty || + rhs.kind == 
StructuredSparsityKind::Empty) + return StructuredSparsityPattern(StructuredSparsityKind::Empty); + if (lhs.kind == StructuredSparsityKind::Unknown) + return rhs; + if (rhs.kind == StructuredSparsityKind::Unknown) + return lhs; + + if (lhs.kind == StructuredSparsityKind::Band && + rhs.kind == StructuredSparsityKind::Band) { + return StructuredSparsityPattern( + std::min(lhs.lowerBandwidth, rhs.lowerBandwidth), + std::min(lhs.upperBandwidth, rhs.upperBandwidth)); + } + + return lhs <= rhs ? lhs : rhs; + } + + static StructuredSparsityPattern join(const StructuredSparsityPattern &lhs, + const StructuredSparsityPattern &rhs) { + if (lhs.kind == StructuredSparsityKind::Unknown || + rhs.kind == StructuredSparsityKind::Unknown) + return StructuredSparsityPattern(StructuredSparsityKind::Unknown); + if (lhs.kind == StructuredSparsityKind::Empty) + return rhs; + if (rhs.kind == StructuredSparsityKind::Empty) + return lhs; + + if (lhs.kind == StructuredSparsityKind::Band && + rhs.kind == StructuredSparsityKind::Band) { + return StructuredSparsityPattern( + std::max(lhs.lowerBandwidth, rhs.lowerBandwidth), + std::max(lhs.upperBandwidth, rhs.upperBandwidth)); + } + + return StructuredSparsityPattern(StructuredSparsityKind::Dense); + } + + bool operator==(const StructuredSparsityPattern &other) const {} + + bool operator<=(const StructuredSparsityPattern &other) const { + if (kind == StructuredSparsityKind::Empty) + return true; + + if (other.kind == StructuredSparsityKind::Unknown) + return true; + + if (other.kind == StructuredSparsityKind::Empty) + return kind == StructuredSparsityKind::Empty; + + if (kind == StructuredSparsityKind::Unknown) + return other.kind == StructuredSparsityKind::Unknown; + + if (kind == other.kind) { + if (kind == StructuredSparsityKind::Band) { + return lowerBandwidth <= other.lowerBandwidth && + upperBandwidth <= other.upperBandwidth; + } + return true; + } + + if (kind == StructuredSparsityKind::Diagonal) { + return other.kind != 
StructuredSparsityKind::Empty; + } + + if (kind == StructuredSparsityKind::Bidiagonal) { + return other.kind == StructuredSparsityKind::Tridiagonal || + other.kind == StructuredSparsityKind::Band || + other.kind == StructuredSparsityKind::UpperTriangular || + other.kind == StructuredSparsityKind::Dense; + } + + if (kind == StructuredSparsityKind::Tridiagonal) { + return other.kind == StructuredSparsityKind::Band || + other.kind == StructuredSparsityKind::Dense; + } + + if (kind == StructuredSparsityKind::UpperTriangular || + kind == StructuredSparsityKind::LowerTriangular) { + if (other.kind == StructuredSparsityKind::Dense) + return true; + if (other.kind == StructuredSparsityKind::Band) { + if (kind == StructuredSparsityKind::UpperTriangular) { + return other.lowerBandwidth == 0; + } else { + return other.upperBandwidth == 0; + } + } + return false; + } + + if (kind == StructuredSparsityKind::Band) { + return other.kind == StructuredSparsityKind::Dense; + } + + if (kind == StructuredSparsityKind::Dense) { + return other.kind == StructuredSparsityKind::Dense || + other.kind == StructuredSparsityKind::Unknown; + } + + return false; + } + +private: + void initializeBandwidths(); + + StructuredSparsityKind kind; + int64_t lowerBandwidth; + int64_t upperBandwidth; +}; + +//===----------------------------------------------------------------------===// +// Value Properties Implementation +//===----------------------------------------------------------------------===// + +enum class ValueProperty { + UnitDiagonal = 1 << 0, + Symmetric = 1 << 1, + Hermitian = 1 << 2, +}; + +class ValueProperties { +public: + ValueProperties() = default; + explicit ValueProperties(uint32_t flags) : flags(flags) {} + + void set(ValueProperty property) { flags |= static_cast(property); } + void clear(ValueProperty property) { + flags &= ~static_cast(property); + } + bool has(ValueProperty property) const { + return flags & static_cast(property); + } + + bool hasUnitDiagonal() const { return 
has(ValueProperty::UnitDiagonal); } + bool isSymmetric() const { return has(ValueProperty::Symmetric); } + bool isHermitian() const { return has(ValueProperty::Hermitian); } + + uint32_t getFlags() const { return flags; } + + // partial ordering + static ValueProperties meet(const ValueProperties &lhs, + const ValueProperties &rhs) { + return ValueProperties(lhs.flags & rhs.flags); + } + + static ValueProperties join(const ValueProperties &lhs, + const ValueProperties &rhs) { + return ValueProperties(lhs.flags | rhs.flags); + } + + bool operator==(const ValueProperties &other) const { + return flags == other.flags; + } + + bool operator<=(const ValueProperties &other) const { + return (flags & other.flags) == flags; + } + +private: + uint32_t flags; +}; + +//===----------------------------------------------------------------------===// +// Structured Matrix Type +//===----------------------------------------------------------------------===// + +class StructuredMatrixType { +public: + StructuredMatrixType() = default; + StructuredMatrixType(StructuredSparsityPattern sparsityPattern, + ValueProperties valueProperties) + : sparsityPattern(sparsityPattern), valueProperties(valueProperties) {} + + const StructuredSparsityPattern &getSparsityPattern() const { + return sparsityPattern; + } + const ValueProperties &getProperties() const { return valueProperties; } + + // partial ordering + static StructuredMatrixType meet(const StructuredMatrixType &lhs, + const StructuredMatrixType &rhs) { + return StructuredMatrixType( + StructuredSparsityPattern::meet(lhs.sparsityPattern, + rhs.sparsityPattern), + ValueProperties::meet(lhs.valueProperties, rhs.valueProperties)); + } + + static StructuredMatrixType join(const StructuredMatrixType &lhs, + const StructuredMatrixType &rhs) { + return StructuredMatrixType( + StructuredSparsityPattern::join(lhs.sparsityPattern, + rhs.sparsityPattern), + ValueProperties::join(lhs.valueProperties, rhs.valueProperties)); + } + + bool 
operator==(const StructuredMatrixType &other) const { + return sparsityPattern == other.sparsityPattern && + valueProperties == other.valueProperties; + } + +private: + StructuredSparsityPattern sparsityPattern; + ValueProperties valueProperties; +}; + +} // namespace structure_analysis +} // namespace mlir diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index 975f166f9..20ba941ce 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -836,6 +836,28 @@ cc_library( ], ) +cc_library( + name = "StructuredMatrixAnalysis", + srcs = ["Analysis/StructuredMatrixAnalysis.cpp"], + hdrs = ["Analysis/StructuredMatrixAnalysis.h"], + deps = [ + # MLIR Core + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Transforms", + + # MLIR Analysis + "@llvm-project//mlir:Analysis", + + # MLIR Dialects + "@stablehlo//:stablehlo_ops", + + # LLVM Support + "@llvm-project//llvm:Support", + ], +) + cc_library( name = "XLADerivatives", srcs = glob([ From 474c8bfe28483044888c3c74569c67400c640410 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 18 Nov 2025 12:51:19 -0600 Subject: [PATCH 02/11] fix: meet and join --- .../jax/Analysis/StructuredMatrixAnalysis.cpp | 73 ++++++++-- .../jax/Analysis/StructuredMatrixAnalysis.h | 126 ++++-------------- 2 files changed, 89 insertions(+), 110 deletions(-) diff --git a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp index 948b5c331..2828199f8 100644 --- a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp +++ b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp @@ -23,31 +23,86 @@ namespace structure_analysis { void StructuredSparsityPattern::initializeBandwidths() { switch (kind) { - case StructuredSparsityKind::Diagonal: + case StructuredSparsityKind::Unknown: + break; // leave as is + case StructuredSparsityKind::Dense: + lowerBandwidth = std::numeric_limits::max(); + upperBandwidth = 
std::numeric_limits::max(); + break; // Dense is a valid kind: must not fall through into the unreachable Band case + case StructuredSparsityKind::Band: + llvm_unreachable("constructing band with no bandwidths"); + case StructuredSparsityKind::UpperTriangular: lowerBandwidth = 0; - upperBandwidth = 0; + upperBandwidth = std::numeric_limits::max(); break; - case StructuredSparsityKind::Bidiagonal: + case StructuredSparsityKind::UpperBidiagonal: lowerBandwidth = 0; upperBandwidth = 1; break; + case StructuredSparsityKind::LowerTriangular: + lowerBandwidth = std::numeric_limits::max(); + upperBandwidth = 0; + break; + case StructuredSparsityKind::LowerBidiagonal: + lowerBandwidth = 1; + upperBandwidth = 0; + break; case StructuredSparsityKind::Tridiagonal: lowerBandwidth = 1; upperBandwidth = 1; break; - case StructuredSparsityKind::UpperTriangular: + case StructuredSparsityKind::Diagonal: lowerBandwidth = 0; - upperBandwidth = std::numeric_limits::max(); - break; - case StructuredSparsityKind::LowerTriangular: - lowerBandwidth = std::numeric_limits::max(); upperBandwidth = 0; break; - default: + case StructuredSparsityKind::Empty: break; } } +void StructuredSparsityPattern::refineKind() { + if (kind != StructuredSparsityKind::Band) + return; + + if (lowerBandwidth == 0) { + if (upperBandwidth == 0) { + kind = StructuredSparsityKind::Diagonal; + return; + } + if (upperBandwidth == 1) { + kind = StructuredSparsityKind::UpperBidiagonal; + return; + } + if (upperBandwidth == std::numeric_limits::max()) { + kind = StructuredSparsityKind::UpperTriangular; + return; + } + } + + // lowerBandwidth != 0 + if (upperBandwidth == 0) { + if (lowerBandwidth == 1) { + kind = StructuredSparsityKind::LowerBidiagonal; + return; + } + if (lowerBandwidth == std::numeric_limits::max()) { + kind = StructuredSparsityKind::LowerTriangular; + return; + } + } + + if (lowerBandwidth == 1 && upperBandwidth == 1) { + kind = StructuredSparsityKind::Tridiagonal; + return; + } + + if (lowerBandwidth == std::numeric_limits::max() && + upperBandwidth == std::numeric_limits::max()) { + 
kind = StructuredSparsityKind::Dense; + return; + } +} + //===----------------------------------------------------------------------===// // Value Properties Implementation //===----------------------------------------------------------------------===// diff --git a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h index b630d55d4..67b489bb3 100644 --- a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h +++ b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h @@ -13,11 +13,12 @@ enum class StructuredSparsityKind { Dense, Band, UpperTriangular, + UpperBidiagonal, LowerTriangular, - Diagonal, - Bidiagonal, + LowerBidiagonal, Tridiagonal, - Empty, + Diagonal, + Empty, // doesn't really mean anything, but we need it for bottom element }; // TODO: currently only legal negative value is -1, which means "unknown" @@ -35,115 +36,52 @@ class StructuredSparsityPattern { StructuredSparsityPattern(int64_t lowerBandwidth, int64_t upperBandwidth) : kind(StructuredSparsityKind::Band), lowerBandwidth(lowerBandwidth), - upperBandwidth(upperBandwidth) {} + upperBandwidth(upperBandwidth) { + refineKind(); + } + // most precise of lhs and rhs static StructuredSparsityPattern meet(const StructuredSparsityPattern &lhs, const StructuredSparsityPattern &rhs) { if (lhs.kind == StructuredSparsityKind::Empty || rhs.kind == StructuredSparsityKind::Empty) return StructuredSparsityPattern(StructuredSparsityKind::Empty); + if (lhs.kind == StructuredSparsityKind::Unknown) return rhs; if (rhs.kind == StructuredSparsityKind::Unknown) return lhs; - if (lhs.kind == StructuredSparsityKind::Band && - rhs.kind == StructuredSparsityKind::Band) { - return StructuredSparsityPattern( - std::min(lhs.lowerBandwidth, rhs.lowerBandwidth), - std::min(lhs.upperBandwidth, rhs.upperBandwidth)); - } - - return lhs <= rhs ? 
lhs : rhs; + // for all other cases, we take the min of the bandwidths and refine + auto lb = std::min(lhs.lowerBandwidth, rhs.lowerBandwidth); + auto ub = std::min(lhs.upperBandwidth, rhs.upperBandwidth); + auto newPattern = StructuredSparsityPattern(lb, ub); + newPattern.refineKind(); + return newPattern; } + // least precise of lhs and rhs static StructuredSparsityPattern join(const StructuredSparsityPattern &lhs, const StructuredSparsityPattern &rhs) { - if (lhs.kind == StructuredSparsityKind::Unknown || - rhs.kind == StructuredSparsityKind::Unknown) - return StructuredSparsityPattern(StructuredSparsityKind::Unknown); if (lhs.kind == StructuredSparsityKind::Empty) return rhs; if (rhs.kind == StructuredSparsityKind::Empty) return lhs; - if (lhs.kind == StructuredSparsityKind::Band && - rhs.kind == StructuredSparsityKind::Band) { - return StructuredSparsityPattern( - std::max(lhs.lowerBandwidth, rhs.lowerBandwidth), - std::max(lhs.upperBandwidth, rhs.upperBandwidth)); - } - - return StructuredSparsityPattern(StructuredSparsityKind::Dense); - } + if (lhs.kind == StructuredSparsityKind::Unknown || + rhs.kind == StructuredSparsityKind::Unknown) + return StructuredSparsityPattern(StructuredSparsityKind::Unknown); - bool operator==(const StructuredSparsityPattern &other) const {} - - bool operator<=(const StructuredSparsityPattern &other) const { - if (kind == StructuredSparsityKind::Empty) - return true; - - if (other.kind == StructuredSparsityKind::Unknown) - return true; - - if (other.kind == StructuredSparsityKind::Empty) - return kind == StructuredSparsityKind::Empty; - - if (kind == StructuredSparsityKind::Unknown) - return other.kind == StructuredSparsityKind::Unknown; - - if (kind == other.kind) { - if (kind == StructuredSparsityKind::Band) { - return lowerBandwidth <= other.lowerBandwidth && - upperBandwidth <= other.upperBandwidth; - } - return true; - } - - if (kind == StructuredSparsityKind::Diagonal) { - return other.kind != StructuredSparsityKind::Empty; 
- } - - if (kind == StructuredSparsityKind::Bidiagonal) { - return other.kind == StructuredSparsityKind::Tridiagonal || - other.kind == StructuredSparsityKind::Band || - other.kind == StructuredSparsityKind::UpperTriangular || - other.kind == StructuredSparsityKind::Dense; - } - - if (kind == StructuredSparsityKind::Tridiagonal) { - return other.kind == StructuredSparsityKind::Band || - other.kind == StructuredSparsityKind::Dense; - } - - if (kind == StructuredSparsityKind::UpperTriangular || - kind == StructuredSparsityKind::LowerTriangular) { - if (other.kind == StructuredSparsityKind::Dense) - return true; - if (other.kind == StructuredSparsityKind::Band) { - if (kind == StructuredSparsityKind::UpperTriangular) { - return other.lowerBandwidth == 0; - } else { - return other.upperBandwidth == 0; - } - } - return false; - } - - if (kind == StructuredSparsityKind::Band) { - return other.kind == StructuredSparsityKind::Dense; - } - - if (kind == StructuredSparsityKind::Dense) { - return other.kind == StructuredSparsityKind::Dense || - other.kind == StructuredSparsityKind::Unknown; - } - - return false; + auto lb = std::max(lhs.lowerBandwidth, rhs.lowerBandwidth); + auto ub = std::max(lhs.upperBandwidth, rhs.upperBandwidth); + auto newPattern = StructuredSparsityPattern(lb, ub); + newPattern.refineKind(); + return newPattern; } private: void initializeBandwidths(); + void refineKind(); StructuredSparsityKind kind; int64_t lowerBandwidth; @@ -179,7 +117,6 @@ class ValueProperties { uint32_t getFlags() const { return flags; } - // partial ordering static ValueProperties meet(const ValueProperties &lhs, const ValueProperties &rhs) { return ValueProperties(lhs.flags & rhs.flags); @@ -190,14 +127,6 @@ class ValueProperties { return ValueProperties(lhs.flags | rhs.flags); } - bool operator==(const ValueProperties &other) const { - return flags == other.flags; - } - - bool operator<=(const ValueProperties &other) const { - return (flags & other.flags) == flags; - } - 
private: uint32_t flags; }; @@ -235,11 +164,6 @@ class StructuredMatrixType { ValueProperties::join(lhs.valueProperties, rhs.valueProperties)); } - bool operator==(const StructuredMatrixType &other) const { - return sparsityPattern == other.sparsityPattern && - valueProperties == other.valueProperties; - } - private: StructuredSparsityPattern sparsityPattern; ValueProperties valueProperties; From b5df7bbed6af7c96a7c81f31962f7a4fd13cdc3f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 18 Nov 2025 15:28:54 -0600 Subject: [PATCH 03/11] feat: implement sparse dataflow analysis api --- .../jax/Analysis/StructuredMatrixAnalysis.cpp | 167 ++++++++++++++++++ .../jax/Analysis/StructuredMatrixAnalysis.h | 142 +++++++++------ test/lit_tests/structured/symmetric.mlir | 10 ++ 3 files changed, 268 insertions(+), 51 deletions(-) create mode 100644 test/lit_tests/structured/symmetric.mlir diff --git a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp index 2828199f8..643ea8ea2 100644 --- a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp +++ b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp @@ -102,9 +102,176 @@ void StructuredSparsityPattern::refineKind() { } } +// most specific pattern +StructuredSparsityPattern +StructuredSparsityPattern::meet(const StructuredSparsityPattern &lhs, + const StructuredSparsityPattern &rhs) { + if (lhs.kind == StructuredSparsityKind::Empty || + rhs.kind == StructuredSparsityKind::Empty) + return StructuredSparsityPattern(StructuredSparsityKind::Empty); + + if (lhs.kind == StructuredSparsityKind::Unknown) + return rhs; + if (rhs.kind == StructuredSparsityKind::Unknown) + return lhs; + + // for all other cases, we take the min of the bandwidths and refine + auto lb = std::min(lhs.lowerBandwidth, rhs.lowerBandwidth); + auto ub = std::min(lhs.upperBandwidth, rhs.upperBandwidth); + auto newPattern = StructuredSparsityPattern(lb, ub); + newPattern.refineKind(); + return 
newPattern; +} + +// least specific structure containing both +StructuredSparsityPattern +StructuredSparsityPattern::join(const StructuredSparsityPattern &lhs, + const StructuredSparsityPattern &rhs) { + if (lhs.kind == StructuredSparsityKind::Empty) + return rhs; + if (rhs.kind == StructuredSparsityKind::Empty) + return lhs; + + if (lhs.kind == StructuredSparsityKind::Unknown || + rhs.kind == StructuredSparsityKind::Unknown) + return StructuredSparsityPattern(StructuredSparsityKind::Unknown); + + auto lb = std::max(lhs.lowerBandwidth, rhs.lowerBandwidth); + auto ub = std::max(lhs.upperBandwidth, rhs.upperBandwidth); + auto newPattern = StructuredSparsityPattern(lb, ub); + newPattern.refineKind(); + return newPattern; +} + +void StructuredSparsityPattern::print(raw_ostream &os) const { + switch (kind) { + case StructuredSparsityKind::Unknown: + os << "Unknown"; + break; + case StructuredSparsityKind::Dense: + os << "Dense"; + break; + case StructuredSparsityKind::Band: + os << "Band(" << lowerBandwidth << ", " << upperBandwidth << ")"; + break; + case StructuredSparsityKind::UpperTriangular: + os << "UpperTriangular"; + break; + case StructuredSparsityKind::UpperBidiagonal: + os << "UpperBidiagonal"; + break; + case StructuredSparsityKind::LowerTriangular: + os << "LowerTriangular"; + break; + case StructuredSparsityKind::LowerBidiagonal: + os << "LowerBidiagonal"; + break; + case StructuredSparsityKind::Tridiagonal: + os << "Tridiagonal"; + break; + case StructuredSparsityKind::Diagonal: + os << "Diagonal"; + break; + case StructuredSparsityKind::Empty: + os << "Empty"; + break; + } +} + //===----------------------------------------------------------------------===// // Value Properties Implementation //===----------------------------------------------------------------------===// +ValueProperties ValueProperties::meet(const ValueProperties &lhs, + const ValueProperties &rhs) { + return ValueProperties(lhs.flags & rhs.flags); +} + +ValueProperties 
ValueProperties::join(const ValueProperties &lhs, + const ValueProperties &rhs) { + return ValueProperties(lhs.flags | rhs.flags); +} + +void ValueProperties::print(raw_ostream &os) const { + os << "{"; + bool first = true; + auto add = [&](const char *s) { + if (!first) + os << ", "; + os << s; + first = false; + }; + + if (hasUnitDiagonal()) + add("UnitDiagonal"); + if (isSymmetric()) + add("Symmetric"); + if (isHermitian()) + add("Hermitian"); + if (isBroadcastedScalar()) + add("BroadcastedScalar"); + + os << "}"; +} + +//===----------------------------------------------------------------------===// +// Structured Matrix Type +//===----------------------------------------------------------------------===// + +StructuredMatrixType +StructuredMatrixType::meet(const StructuredMatrixType &lhs, + const StructuredMatrixType &rhs) { + return StructuredMatrixType( + StructuredSparsityPattern::meet(lhs.sparsityPattern, rhs.sparsityPattern), + ValueProperties::meet(lhs.valueProperties, rhs.valueProperties)); +} + +StructuredMatrixType +StructuredMatrixType::join(const StructuredMatrixType &lhs, + const StructuredMatrixType &rhs) { + return StructuredMatrixType( + StructuredSparsityPattern::join(lhs.sparsityPattern, rhs.sparsityPattern), + ValueProperties::join(lhs.valueProperties, rhs.valueProperties)); +} + +void StructuredMatrixType::print(raw_ostream &os) const { + os << "StructuredMatrixType("; + sparsityPattern.print(os); + os << " "; + valueProperties.print(os); + os << ")"; +} + +//===----------------------------------------------------------------------===// +// Lattice Element +//===----------------------------------------------------------------------===// + +void StructuredMatrixLattice::print(raw_ostream &os) const { + os << "StructuredMatrixLattice("; + value.print(os); + os << ")"; +} + +//===----------------------------------------------------------------------===// +// Dataflow Analysis 
+//===----------------------------------------------------------------------===// + +void StructuredMatrixAnalysis::setToEntryState( + StructuredMatrixLattice *lattice) { + lattice->setValue(StructuredMatrixType()); +} + +LogicalResult StructuredMatrixAnalysis::visitOperation( + Operation *op, ArrayRef operands, + ArrayRef results) { + llvm::errs() << "Visiting operation " << *op << "\n"; + + return success(); +} + +//===----------------------------------------------------------------------===// +// Structure Originators +//===----------------------------------------------------------------------===// + } // namespace structure_analysis } // namespace mlir diff --git a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h index 67b489bb3..32433317f 100644 --- a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h +++ b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h @@ -1,3 +1,5 @@ +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" + #include #include @@ -40,43 +42,21 @@ class StructuredSparsityPattern { refineKind(); } - // most precise of lhs and rhs static StructuredSparsityPattern meet(const StructuredSparsityPattern &lhs, - const StructuredSparsityPattern &rhs) { - if (lhs.kind == StructuredSparsityKind::Empty || - rhs.kind == StructuredSparsityKind::Empty) - return StructuredSparsityPattern(StructuredSparsityKind::Empty); - - if (lhs.kind == StructuredSparsityKind::Unknown) - return rhs; - if (rhs.kind == StructuredSparsityKind::Unknown) - return lhs; - - // for all other cases, we take the min of the bandwidths and refine - auto lb = std::min(lhs.lowerBandwidth, rhs.lowerBandwidth); - auto ub = std::min(lhs.upperBandwidth, rhs.upperBandwidth); - auto newPattern = StructuredSparsityPattern(lb, ub); - newPattern.refineKind(); - return newPattern; - } + const StructuredSparsityPattern &rhs); - // least precise of lhs and rhs static StructuredSparsityPattern join(const StructuredSparsityPattern 
&lhs, - const StructuredSparsityPattern &rhs) { - if (lhs.kind == StructuredSparsityKind::Empty) - return rhs; - if (rhs.kind == StructuredSparsityKind::Empty) - return lhs; - - if (lhs.kind == StructuredSparsityKind::Unknown || - rhs.kind == StructuredSparsityKind::Unknown) - return StructuredSparsityPattern(StructuredSparsityKind::Unknown); - - auto lb = std::max(lhs.lowerBandwidth, rhs.lowerBandwidth); - auto ub = std::max(lhs.upperBandwidth, rhs.upperBandwidth); - auto newPattern = StructuredSparsityPattern(lb, ub); - newPattern.refineKind(); - return newPattern; + const StructuredSparsityPattern &rhs); + + bool operator==(const StructuredSparsityPattern &other) const { + return kind == other.kind && lowerBandwidth == other.lowerBandwidth && + upperBandwidth == other.upperBandwidth; + } + + void print(raw_ostream &os) const; + raw_ostream &operator<<(raw_ostream &os) const { + print(os); + return os; } private: @@ -96,6 +76,7 @@ enum class ValueProperty { UnitDiagonal = 1 << 0, Symmetric = 1 << 1, Hermitian = 1 << 2, + BroadcastedScalar = 1 << 3, }; class ValueProperties { @@ -114,17 +95,26 @@ class ValueProperties { bool hasUnitDiagonal() const { return has(ValueProperty::UnitDiagonal); } bool isSymmetric() const { return has(ValueProperty::Symmetric); } bool isHermitian() const { return has(ValueProperty::Hermitian); } + bool isBroadcastedScalar() const { + return has(ValueProperty::BroadcastedScalar); + } + + void print(raw_ostream &os) const; + raw_ostream &operator<<(raw_ostream &os) const { + print(os); + return os; + } uint32_t getFlags() const { return flags; } static ValueProperties meet(const ValueProperties &lhs, - const ValueProperties &rhs) { - return ValueProperties(lhs.flags & rhs.flags); - } + const ValueProperties &rhs); static ValueProperties join(const ValueProperties &lhs, - const ValueProperties &rhs) { - return ValueProperties(lhs.flags | rhs.flags); + const ValueProperties &rhs); + + bool operator==(const ValueProperties &other) const { + 
return flags == other.flags; } private: @@ -147,27 +137,77 @@ class StructuredMatrixType { } const ValueProperties &getProperties() const { return valueProperties; } - // partial ordering static StructuredMatrixType meet(const StructuredMatrixType &lhs, - const StructuredMatrixType &rhs) { - return StructuredMatrixType( - StructuredSparsityPattern::meet(lhs.sparsityPattern, - rhs.sparsityPattern), - ValueProperties::meet(lhs.valueProperties, rhs.valueProperties)); - } + const StructuredMatrixType &rhs); static StructuredMatrixType join(const StructuredMatrixType &lhs, - const StructuredMatrixType &rhs) { - return StructuredMatrixType( - StructuredSparsityPattern::join(lhs.sparsityPattern, - rhs.sparsityPattern), - ValueProperties::join(lhs.valueProperties, rhs.valueProperties)); + const StructuredMatrixType &rhs); + + bool operator==(const StructuredMatrixType &other) const { + return sparsityPattern == other.sparsityPattern && + valueProperties == other.valueProperties; } + void print(raw_ostream &os) const; + raw_ostream &operator<<(raw_ostream &os) const { + print(os); + return os; + } + + // TODO: propagation rules probably goes in here + + // TODO: implement queries that check both the sparsity pattern and value + // properties and return specific matrix kinds + private: StructuredSparsityPattern sparsityPattern; ValueProperties valueProperties; }; +//===----------------------------------------------------------------------===// +// Lattice Element +//===----------------------------------------------------------------------===// + +class StructuredMatrixLattice : public dataflow::AbstractSparseLattice { +public: + using AbstractSparseLattice::AbstractSparseLattice; + + ChangeResult meet(const AbstractSparseLattice &rhs) override; + ChangeResult join(const AbstractSparseLattice &rhs) override; + + void print(raw_ostream &os) const override; + raw_ostream &operator<<(raw_ostream &os) const { + print(os); + return os; + } + + const StructuredMatrixType 
&getValue() const { return value; } + void setValue(const StructuredMatrixType &v) { value = v; } + +private: + StructuredMatrixType value; +}; + +//===----------------------------------------------------------------------===// +// Dataflow Analysis +//===----------------------------------------------------------------------===// + +class StructuredMatrixAnalysis + : public dataflow::SparseForwardDataFlowAnalysis { +public: + using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis; + + void setToEntryState(StructuredMatrixLattice *lattice) override; + + LogicalResult + visitOperation(Operation *op, + ArrayRef operands, + ArrayRef results) override; +}; + +//===----------------------------------------------------------------------===// +// Structure Originators +//===----------------------------------------------------------------------===// + } // namespace structure_analysis } // namespace mlir diff --git a/test/lit_tests/structured/symmetric.mlir b/test/lit_tests/structured/symmetric.mlir new file mode 100644 index 000000000..4ddcf31a5 --- /dev/null +++ b/test/lit_tests/structured/symmetric.mlir @@ -0,0 +1,10 @@ +module { + func.func @symmetric(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor<2x2xf32> + %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32> + %1 = stablehlo.add %0, %arg0 : tensor<2x2xf32> + %2 = stablehlo.transpose %cst, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32> + %3 = stablehlo.add %1, %2 : tensor<2x2xf32> + return %3 : tensor<2x2xf32> + } +} From 60f38f2f834df39b49cd411f4ff737f1da8cdde0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 18 Nov 2025 15:59:03 -0600 Subject: [PATCH 04/11] feat: add stub pass --- .../jax/Analysis/StructuredMatrixAnalysis.cpp | 39 +++++++++++++ .../jax/Analysis/StructuredMatrixAnalysis.h | 14 +++++ src/enzyme_ad/jax/BUILD | 1 + src/enzyme_ad/jax/Passes/Passes.td | 8 +++ .../jax/Passes/StructuredMatrixSimplify.cpp 
| 55 +++++++++++++++++++ test/lit_tests/structured/symmetric.mlir | 2 + 6 files changed, 119 insertions(+) create mode 100644 src/enzyme_ad/jax/Passes/StructuredMatrixSimplify.cpp diff --git a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp index 643ea8ea2..c776fab95 100644 --- a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp +++ b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp @@ -21,6 +21,11 @@ namespace structure_analysis { // Structured Sparsity Pattern Implementation //===----------------------------------------------------------------------===// +StructuredSparsityPattern::StructuredSparsityPattern(Value v) { + llvm::errs() << "TODO: structured sparsity pattern not implemented for " << v + << "\n"; +} + void StructuredSparsityPattern::initializeBandwidths() { switch (kind) { case StructuredSparsityKind::Unknown: @@ -182,6 +187,10 @@ void StructuredSparsityPattern::print(raw_ostream &os) const { // Value Properties Implementation //===----------------------------------------------------------------------===// +ValueProperties::ValueProperties(Value v) { + llvm::errs() << "TODO: value properties not implemented for " << v << "\n"; +} + ValueProperties ValueProperties::meet(const ValueProperties &lhs, const ValueProperties &rhs) { return ValueProperties(lhs.flags & rhs.flags); @@ -246,6 +255,36 @@ void StructuredMatrixType::print(raw_ostream &os) const { // Lattice Element //===----------------------------------------------------------------------===// +ChangeResult StructuredMatrixLattice::meet(const AbstractSparseLattice &rhs) { + const auto *rhsStruct = + reinterpret_cast(&rhs); + return meet(*rhsStruct); +} + +ChangeResult StructuredMatrixLattice::meet(StructuredMatrixLattice rhs) { + auto newValue = StructuredMatrixType::meet(getValue(), rhs.getValue()); + if (getValue() == newValue) + return ChangeResult::NoChange; + + setValue(newValue); + return ChangeResult::Change; +} + 
+ChangeResult StructuredMatrixLattice::join(const AbstractSparseLattice &rhs) { + const auto *rhsStruct = + reinterpret_cast(&rhs); + return join(*rhsStruct); +} + +ChangeResult StructuredMatrixLattice::join(StructuredMatrixLattice rhs) { + auto newValue = StructuredMatrixType::join(getValue(), rhs.getValue()); + if (getValue() == newValue) + return ChangeResult::NoChange; + + setValue(newValue); + return ChangeResult::Change; +} + void StructuredMatrixLattice::print(raw_ostream &os) const { os << "StructuredMatrixLattice("; value.print(os); diff --git a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h index 32433317f..3b0fdd0b6 100644 --- a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h +++ b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h @@ -36,6 +36,8 @@ class StructuredSparsityPattern { initializeBandwidths(); } + StructuredSparsityPattern(Value v); + StructuredSparsityPattern(int64_t lowerBandwidth, int64_t upperBandwidth) : kind(StructuredSparsityKind::Band), lowerBandwidth(lowerBandwidth), upperBandwidth(upperBandwidth) { @@ -84,6 +86,8 @@ class ValueProperties { ValueProperties() = default; explicit ValueProperties(uint32_t flags) : flags(flags) {} + ValueProperties(Value v); + void set(ValueProperty property) { flags |= static_cast(property); } void clear(ValueProperty property) { flags &= ~static_cast(property); @@ -132,6 +136,10 @@ class StructuredMatrixType { ValueProperties valueProperties) : sparsityPattern(sparsityPattern), valueProperties(valueProperties) {} + StructuredMatrixType(Value v) + : StructuredMatrixType(StructuredSparsityPattern(v), ValueProperties(v)) { + } + const StructuredSparsityPattern &getSparsityPattern() const { return sparsityPattern; } @@ -172,8 +180,14 @@ class StructuredMatrixLattice : public dataflow::AbstractSparseLattice { public: using AbstractSparseLattice::AbstractSparseLattice; + StructuredMatrixLattice(Value v) + : AbstractSparseLattice(v), 
value(StructuredMatrixType(v)) {} + ChangeResult meet(const AbstractSparseLattice &rhs) override; + ChangeResult meet(StructuredMatrixLattice rhs); + ChangeResult join(const AbstractSparseLattice &rhs) override; + ChangeResult join(StructuredMatrixLattice rhs); void print(raw_ostream &os) const override; raw_ostream &operator<<(raw_ostream &os) const { diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index 20ba941ce..fae3e850f 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -903,6 +903,7 @@ cc_library( ":RaisingTransformOpsIncGen", ":RaisingTransformPatternsIncGen", ":StablehloOptPatternsIncGen", + ":StructuredMatrixAnalysis", ":TesseraDialectIncGen", ":TesseraOpsIncGen", ":TritonExtDialect", diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td index 5bf009201..b872474b4 100644 --- a/src/enzyme_ad/jax/Passes/Passes.td +++ b/src/enzyme_ad/jax/Passes/Passes.td @@ -1077,4 +1077,12 @@ def EnzymeBatchToStableHLOPass : Pass<"enzyme-batch-to-stablehlo"> { ]; } +def StructuredMatrixSimplifyPass : Pass<"structured-matrix-simplify", "ModuleOp"> { + let summary = "Simplify structured matrix operations"; + let dependentDialects = [ + "stablehlo::StablehloDialect", + "func::FuncDialect", + ]; +} + #endif diff --git a/src/enzyme_ad/jax/Passes/StructuredMatrixSimplify.cpp b/src/enzyme_ad/jax/Passes/StructuredMatrixSimplify.cpp new file mode 100644 index 000000000..4c1ddfa58 --- /dev/null +++ b/src/enzyme_ad/jax/Passes/StructuredMatrixSimplify.cpp @@ -0,0 +1,55 @@ +#include "src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h" +#include "src/enzyme_ad/jax/Passes/Passes.h" + +#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Pass/Pass.h" + +#include "mlir/Support/LLVM.h" +#include 
"llvm/Support/raw_ostream.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "stablehlo/dialect/StablehloOps.h" + +#define DEBUG_TYPE "structured-matrix-simplify" + +namespace mlir { +namespace enzyme { +#define GEN_PASS_DEF_STRUCTUREDMATRIXSIMPLIFYPASS +#include "src/enzyme_ad/jax/Passes/Passes.h.inc" +} // namespace enzyme +} // namespace mlir + +using namespace mlir; +using namespace mlir::dataflow; +using namespace mlir::enzyme; +using namespace mlir::structure_analysis; + +namespace { + +class StructuredMatrixSimplifyPass + : public enzyme::impl::StructuredMatrixSimplifyPassBase< + StructuredMatrixSimplifyPass> { +public: + using Base::Base; + + void runOnOperation() override { + DataFlowSolver solver; + + solver.load(); + solver.load(); + solver.load(); + + if (failed(solver.initializeAndRun(getOperation()))) { + return signalPassFailure(); + } + + // TODO: do things here + } +}; + +} // namespace diff --git a/test/lit_tests/structured/symmetric.mlir b/test/lit_tests/structured/symmetric.mlir index 4ddcf31a5..089d0c073 100644 --- a/test/lit_tests/structured/symmetric.mlir +++ b/test/lit_tests/structured/symmetric.mlir @@ -1,3 +1,5 @@ +// RUN: enzymexlamlir-opt --structured-matrix-simplify %s | FileCheck %s + module { func.func @symmetric(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { %cst = stablehlo.constant dense<0.000000e+00> : tensor<2x2xf32> From 806ff4ca9d342dfa993e3080843ad4fb836bc75b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 18 Nov 2025 16:51:05 -0600 Subject: [PATCH 05/11] feat: value properties check --- .../jax/Analysis/StructuredMatrixAnalysis.cpp | 138 +++++++++++++++++- .../jax/Analysis/StructuredMatrixAnalysis.h | 41 +++++- .../jax/Passes/StructuredMatrixSimplify.cpp | 2 + test/lit_tests/structured/symmetric.mlir | 4 +- 4 files changed, 180 insertions(+), 5 deletions(-) diff --git a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp index c776fab95..a3f906458 100644 
--- a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp +++ b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp @@ -6,6 +6,8 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" @@ -22,8 +24,16 @@ namespace structure_analysis { //===----------------------------------------------------------------------===// StructuredSparsityPattern::StructuredSparsityPattern(Value v) { + if (auto blockArg = dyn_cast(v)) { + // TODO: If block arg is annotated with a pattern we should parse that + setUnknown(); // be pessimistic by default + return; + } + llvm::errs() << "TODO: structured sparsity pattern not implemented for " << v << "\n"; + setUnknown(); + return; } void StructuredSparsityPattern::initializeBandwidths() { @@ -188,7 +198,126 @@ void StructuredSparsityPattern::print(raw_ostream &os) const { //===----------------------------------------------------------------------===// ValueProperties::ValueProperties(Value v) { - llvm::errs() << "TODO: value properties not implemented for " << v << "\n"; + if (auto blockArg = dyn_cast(v)) { + // TODO: If block arg is annotated with a pattern we should parse that + setFlags(0); // be pessimistic by default + return; + } + + auto vTy = cast(v.getType()); + if (!vTy.hasStaticShape() || vTy.getRank() != 2) + return; + auto vShape = vTy.getShape(); + if (vShape[0] != vShape[1]) // TODO: should we allow rectangular matrices? 
+ return; + + DenseElementsAttr denseAttr; + if (matchPattern(v, m_Constant(&denseAttr))) { + auto props = getPropertiesFromDenseAttr(denseAttr); + setFlags(props.getFlags()); + llvm::errs() << "v: " << v << " properties: "; + this->print(llvm::errs()); + llvm::errs() << "\n"; + return; + } + + // TODO: symmetric checks Utils.cpp:688 + + // TODO: broadcasted scalar + + // TODO: unit diagonal + // - iota scatter with constant + + llvm::errs() << "(Not implemented) v: " << v << " properties: "; + this->print(llvm::errs()); + llvm::errs() << "\n"; + return; +} + +ValueProperties +ValueProperties::getPropertiesFromDenseAttr(DenseElementsAttr attr) { + ValueProperties props; + + if (attr.isSplat()) { + auto val = attr.getSplatValue(); + if (utils::isOne(val)) + props.set(ValueProperty::UnitDiagonal); + + props.set(ValueProperty::BroadcastedScalar); + props.set(ValueProperty::Symmetric); + props.set(ValueProperty::Hermitian); + return props; + } + + auto type = dyn_cast(attr.getType()); + if (!type) + return props; + + auto shape = type.getShape(); + int64_t nrows = shape[0]; + int64_t ncols = shape[1]; + if (nrows != ncols) + return props; + + if (isUnitDiagonal(attr, nrows, ncols)) + props.set(ValueProperty::UnitDiagonal); + + auto [isSymmetric, isHermitian] = isSymmetricOrHermitian(attr, nrows, ncols); + if (isSymmetric) + props.set(ValueProperty::Symmetric); + if (isHermitian) + props.set(ValueProperty::Hermitian); + + return props; +} + +template +bool isUnitDiagonalImpl(DenseElementsAttr attr, int64_t nrows, int64_t ncols) { + auto values = attr.getValues().begin(); + for (int64_t i = 0; i < std::min(nrows, ncols); i++) { + if (!utils::isOne(values[i])) + return false; + } + return true; +} + +bool ValueProperties::isUnitDiagonal(DenseElementsAttr attr, int64_t nrows, + int64_t ncols) { + if (isa(attr.getElementType())) { + return isUnitDiagonalImpl(attr, nrows, ncols); + } else if (isa(attr.getElementType())) { + return isUnitDiagonalImpl(attr, nrows, ncols); + } 
+ return false; +} + +template +std::tuple isSymmetricOrHermitianImpl(DenseElementsAttr attr, + int64_t nrows, + int64_t ncols) { + auto values = attr.getValues().begin(); + for (int64_t i = 0; i < nrows; i++) { + for (int64_t j = i + 1; j < ncols; j++) { + auto a = *(values + i * ncols + j); + auto b = *(values + j * ncols + i); + if (!utils::areEqual(a, b)) { + return {false, false}; // TODO: check for hermitian + } + } + } + + return {true, false}; // TODO: check for hermitian +} + +std::tuple +ValueProperties::isSymmetricOrHermitian(DenseElementsAttr attr, int64_t nrows, + int64_t ncols) { + if (isa(attr.getElementType())) { + return isSymmetricOrHermitianImpl(attr, nrows, ncols); + } else if (isa(attr.getElementType())) { + return isSymmetricOrHermitianImpl(attr, nrows, ncols); + } + return {false, false}; } ValueProperties ValueProperties::meet(const ValueProperties &lhs, @@ -303,7 +432,14 @@ void StructuredMatrixAnalysis::setToEntryState( LogicalResult StructuredMatrixAnalysis::visitOperation( Operation *op, ArrayRef operands, ArrayRef results) { + llvm::errs() << "Visiting operation " << *op << "\n"; + for (auto operand : operands) { + llvm::errs() << " operand: "; + operand->getValue().print(llvm::errs()); + llvm::errs() << "\n"; + } + llvm::errs() << "\n"; return success(); } diff --git a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h index 3b0fdd0b6..e240e4c18 100644 --- a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h +++ b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h @@ -1,3 +1,5 @@ +#pragma once + #include "mlir/Analysis/DataFlow/SparseAnalysis.h" #include @@ -6,6 +8,25 @@ namespace mlir { namespace structure_analysis { +namespace utils { + +static bool isOne(APInt v) { return v.isOne(); } +static bool isOne(APFloat v) { return v.isExactlyValue(1.0); } +static bool isOne(Attribute v) { + if (auto intAttr = dyn_cast(v)) + return isOne(intAttr.getValue()); + if (auto 
floatAttr = dyn_cast(v)) + return isOne(floatAttr.getValue()); + return false; +} + +static bool areEqual(APInt a, APInt b) { return a == b; } +static bool areEqual(APFloat a, APFloat b) { + return a.compare(b) == llvm::APFloat::cmpEqual; +} + +} // namespace utils + //===----------------------------------------------------------------------===// // Structured Sparsity Pattern Implementation //===----------------------------------------------------------------------===// @@ -28,8 +49,8 @@ enum class StructuredSparsityKind { class StructuredSparsityPattern { public: StructuredSparsityPattern() - : kind(StructuredSparsityKind::Unknown), lowerBandwidth(0), - upperBandwidth(0) {} + : kind(StructuredSparsityKind::Unknown), lowerBandwidth(-1), + upperBandwidth(-1) {} explicit StructuredSparsityPattern(StructuredSparsityKind kind) : kind(kind), lowerBandwidth(-1), upperBandwidth(-1) { @@ -65,6 +86,12 @@ class StructuredSparsityPattern { void initializeBandwidths(); void refineKind(); + void setUnknown() { + kind = StructuredSparsityKind::Unknown; + lowerBandwidth = -1; + upperBandwidth = -1; + } + StructuredSparsityKind kind; int64_t lowerBandwidth; int64_t upperBandwidth; @@ -110,6 +137,7 @@ class ValueProperties { } uint32_t getFlags() const { return flags; } + void setFlags(uint32_t f) { flags = f; } static ValueProperties meet(const ValueProperties &lhs, const ValueProperties &rhs); @@ -122,7 +150,14 @@ class ValueProperties { } private: - uint32_t flags; + static ValueProperties getPropertiesFromDenseAttr(DenseElementsAttr attr); + + static bool isUnitDiagonal(DenseElementsAttr attr, int64_t nrows, + int64_t ncols); + static std::tuple + isSymmetricOrHermitian(DenseElementsAttr, int64_t nrows, int64_t ncols); + + uint32_t flags = 0; }; //===----------------------------------------------------------------------===// diff --git a/src/enzyme_ad/jax/Passes/StructuredMatrixSimplify.cpp b/src/enzyme_ad/jax/Passes/StructuredMatrixSimplify.cpp index 4c1ddfa58..4b31eeaa5 
100644 --- a/src/enzyme_ad/jax/Passes/StructuredMatrixSimplify.cpp +++ b/src/enzyme_ad/jax/Passes/StructuredMatrixSimplify.cpp @@ -48,6 +48,8 @@ class StructuredMatrixSimplifyPass return signalPassFailure(); } + // TODO: annotate the IR with the properties for later usage (use an option) + // TODO: do things here } }; diff --git a/test/lit_tests/structured/symmetric.mlir b/test/lit_tests/structured/symmetric.mlir index 089d0c073..6943ecda7 100644 --- a/test/lit_tests/structured/symmetric.mlir +++ b/test/lit_tests/structured/symmetric.mlir @@ -2,7 +2,9 @@ module { func.func @symmetric(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { - %cst = stablehlo.constant dense<0.000000e+00> : tensor<2x2xf32> + %c0 = stablehlo.constant dense<0.000000e+00> : tensor<2x2xf32> + %c1 = stablehlo.constant dense<1.000000e+00> : tensor<2x2xf32> + %cst = stablehlo.add %c0, %c1 : tensor<2x2xf32> %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32> %1 = stablehlo.add %0, %arg0 : tensor<2x2xf32> %2 = stablehlo.transpose %cst, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32> From eb7c483e6c65cbb2e03cf9a155e8616e18fc22eb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 19 Nov 2025 10:41:51 -0600 Subject: [PATCH 06/11] feat: broadcast check --- .../jax/Analysis/StructuredMatrixAnalysis.cpp | 30 +++++++++++-------- .../jax/Analysis/StructuredMatrixAnalysis.h | 4 --- .../symmetric.mlir | 0 3 files changed, 18 insertions(+), 16 deletions(-) rename test/lit_tests/{structured => structured_tensors}/symmetric.mlir (100%) diff --git a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp index a3f906458..00328b641 100644 --- a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp +++ b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp @@ -13,6 +13,8 @@ #include "mlir/Support/LLVM.h" #include "llvm/Support/raw_ostream.h" +#include "stablehlo/dialect/StablehloOps.h" + using namespace mlir; using 
namespace mlir::dataflow; @@ -75,9 +77,6 @@ void StructuredSparsityPattern::initializeBandwidths() { } void StructuredSparsityPattern::refineKind() { - if (kind != StructuredSparsityKind::Band) - return; - if (lowerBandwidth == 0) { if (upperBandwidth == 0) { kind = StructuredSparsityKind::Diagonal; @@ -215,15 +214,26 @@ ValueProperties::ValueProperties(Value v) { if (matchPattern(v, m_Constant(&denseAttr))) { auto props = getPropertiesFromDenseAttr(denseAttr); setFlags(props.getFlags()); - llvm::errs() << "v: " << v << " properties: "; - this->print(llvm::errs()); - llvm::errs() << "\n"; return; } - // TODO: symmetric checks Utils.cpp:688 + auto defOp = v.getDefiningOp(); + if (!defOp) + return; + + // comm_op(A, A^T) will always be symmetric - // TODO: broadcasted scalar + // A x A^T will always be symmetric + + if (auto bcastOp = dyn_cast(defOp)) { + auto operand = bcastOp.getOperand(); + if (cast(operand.getType()).getRank() == 0) { // bcast(scalar) + if (matchPattern(operand, m_One())) // bcast(1) + set(ValueProperty::UnitDiagonal); + set(ValueProperty::BroadcastedScalar); + set(ValueProperty::Symmetric); + } + } // TODO: unit diagonal // - iota scatter with constant @@ -444,9 +454,5 @@ LogicalResult StructuredMatrixAnalysis::visitOperation( return success(); } -//===----------------------------------------------------------------------===// -// Structure Originators -//===----------------------------------------------------------------------===// - } // namespace structure_analysis } // namespace mlir diff --git a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h index e240e4c18..589ec8aa3 100644 --- a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h +++ b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h @@ -254,9 +254,5 @@ class StructuredMatrixAnalysis ArrayRef results) override; }; -//===----------------------------------------------------------------------===// -// Structure Originators 
-//===----------------------------------------------------------------------===// - } // namespace structure_analysis } // namespace mlir diff --git a/test/lit_tests/structured/symmetric.mlir b/test/lit_tests/structured_tensors/symmetric.mlir similarity index 100% rename from test/lit_tests/structured/symmetric.mlir rename to test/lit_tests/structured_tensors/symmetric.mlir From 188916a3fa7d74d6dba44593f642d7ee797de800 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 23 Nov 2025 17:44:06 -0600 Subject: [PATCH 07/11] feat: annotate IR with structure properties --- .../jax/Analysis/StructuredMatrixAnalysis.cpp | 31 ++++- .../jax/Analysis/StructuredMatrixAnalysis.h | 5 + src/enzyme_ad/jax/BUILD | 25 +--- src/enzyme_ad/jax/Dialect/Dialect.cpp | 117 ++++++++++++++++++ src/enzyme_ad/jax/Dialect/EnzymeXLAAttrs.td | 73 +++++++++++ src/enzyme_ad/jax/Passes/Passes.td | 1 + .../jax/Passes/StructuredMatrixSimplify.cpp | 66 +++++++++- .../structured_tensors/newton_schulz.mlir | 29 +++++ 8 files changed, 319 insertions(+), 28 deletions(-) create mode 100644 test/lit_tests/structured_tensors/newton_schulz.mlir diff --git a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp index 00328b641..148ff55a6 100644 --- a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp +++ b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp @@ -1,4 +1,5 @@ #include "src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h" +#include "src/enzyme_ad/jax/Utils.h" #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" @@ -221,9 +222,33 @@ ValueProperties::ValueProperties(Value v) { if (!defOp) return; + // check that transpose dimensions are [1,0] + auto isTrueTranspose = [](stablehlo::TransposeOp tOp) -> bool { + auto perm = tOp.getPermutation(); + return perm.size() == 2 && perm[0] == 1 && perm[1] == 0; + }; + // comm_op(A, A^T) will always be symmetric + if 
(stablehlo::hasTraitElementwise(defOp) && + (defOp->hasTrait() || + defOp->hasTrait())) { + auto lhs = defOp->getOperand(0); + auto rhs = defOp->getOperand(1); + + if (auto rhsT = rhs.getDefiningOp()) { + if (isTrueTranspose(rhsT) && lhs == rhsT.getOperand()) { + set(ValueProperty::Symmetric); + } + } + + if (auto lhsT = lhs.getDefiningOp()) { + if (isTrueTranspose(lhsT) && rhs == lhsT.getOperand()) { + set(ValueProperty::Symmetric); + } + } + } - // A x A^T will always be symmetric + // TODO: A x A^T will always be symmetric if (auto bcastOp = dyn_cast(defOp)) { auto operand = bcastOp.getOperand(); @@ -232,15 +257,13 @@ ValueProperties::ValueProperties(Value v) { set(ValueProperty::UnitDiagonal); set(ValueProperty::BroadcastedScalar); set(ValueProperty::Symmetric); + return; } } // TODO: unit diagonal // - iota scatter with constant - llvm::errs() << "(Not implemented) v: " << v << " properties: "; - this->print(llvm::errs()); - llvm::errs() << "\n"; return; } diff --git a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h index 589ec8aa3..4ffbe9fc9 100644 --- a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h +++ b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h @@ -2,6 +2,8 @@ #include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "src/enzyme_ad/jax/Dialect/Ops.h" + #include #include @@ -65,6 +67,9 @@ class StructuredSparsityPattern { refineKind(); } + int64_t getLowerBandwidth() const { return lowerBandwidth; } + int64_t getUpperBandwidth() const { return upperBandwidth; } + static StructuredSparsityPattern meet(const StructuredSparsityPattern &lhs, const StructuredSparsityPattern &rhs); diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index fae3e850f..cf247db58 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -836,31 +836,10 @@ cc_library( ], ) -cc_library( - name = "StructuredMatrixAnalysis", - srcs = ["Analysis/StructuredMatrixAnalysis.cpp"], - hdrs 
= ["Analysis/StructuredMatrixAnalysis.h"], - deps = [ - # MLIR Core - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Transforms", - - # MLIR Analysis - "@llvm-project//mlir:Analysis", - - # MLIR Dialects - "@stablehlo//:stablehlo_ops", - - # LLVM Support - "@llvm-project//llvm:Support", - ], -) - cc_library( name = "XLADerivatives", srcs = glob([ + "Analysis/*.cpp", "Implementations/*.cpp", "Passes/*.cpp", "Dialect/*.cpp", @@ -870,6 +849,7 @@ cc_library( "Utils.cpp", ], hdrs = glob([ + "Analysis/*.h", "Implementations/*.h", "Passes/*.h", "Dialect/*.h", @@ -903,7 +883,6 @@ cc_library( ":RaisingTransformOpsIncGen", ":RaisingTransformPatternsIncGen", ":StablehloOptPatternsIncGen", - ":StructuredMatrixAnalysis", ":TesseraDialectIncGen", ":TesseraOpsIncGen", ":TritonExtDialect", diff --git a/src/enzyme_ad/jax/Dialect/Dialect.cpp b/src/enzyme_ad/jax/Dialect/Dialect.cpp index d9e0e1375..3c4ed5537 100644 --- a/src/enzyme_ad/jax/Dialect/Dialect.cpp +++ b/src/enzyme_ad/jax/Dialect/Dialect.cpp @@ -97,6 +97,123 @@ struct EnzymeXLADialectInlinerInterface : public DialectInlinerInterface { } // namespace +void StructuredSparsityAttr::print(::mlir::AsmPrinter &printer) const { + printer << "<"; + + auto pattern = getPattern(); + auto kind = pattern.getKind(); + + // Skip printing kind for Unknown and Empty + bool printKind = (kind != StructuredSparsityKind::Unknown && + kind != StructuredSparsityKind::Empty); + + if (printKind) { + printer << stringifyStructuredSparsityKind(kind); + + // Print bandwidth only for Band + if (kind == StructuredSparsityKind::Band) { + printer << " [" << pattern.getLowerBandwidth() << ", " + << pattern.getUpperBandwidth() << "]"; + } + } + + // Print value properties as a set + auto props = getValueProperties(); + if (!props.empty()) { + if (printKind) { + printer << ", "; + } + printer << "{"; + llvm::interleaveComma(props, printer, [&](StructuredValueProperty prop) { + printer << 
stringifyStructuredValueProperty(prop); + }); + printer << "}"; + } + + printer << ">"; +} + +::mlir::Attribute StructuredSparsityAttr::parse(::mlir::AsmParser &parser, + ::mlir::Type type) { + if (parser.parseLess()) + return {}; + + StructuredSparsityKind kind = StructuredSparsityKind::Unknown; + int64_t lowerBandwidth = -1; + int64_t upperBandwidth = -1; + llvm::SmallVector valueProperties; + + // Check if we start with a brace (properties only, no kind) + if (parser.parseOptionalLBrace().failed()) { + // Try to parse the kind + llvm::StringRef kindStr; + if (succeeded(parser.parseOptionalKeyword(&kindStr))) { + auto kindOpt = symbolizeStructuredSparsityKind(kindStr); + if (!kindOpt) { + parser.emitError(parser.getCurrentLocation(), "invalid sparsity kind: ") + << kindStr; + return {}; + } + kind = *kindOpt; + + // If Band, parse bandwidth bounds + if (kind == StructuredSparsityKind::Band) { + if (parser.parseLSquare() || parser.parseInteger(lowerBandwidth) || + parser.parseComma() || parser.parseInteger(upperBandwidth) || + parser.parseRSquare()) { + return {}; + } + } + + // Check for comma before properties + parser.parseOptionalComma(); + } + + // Try to parse value properties set + if (succeeded(parser.parseOptionalLBrace())) { + // Fall through to parse properties + } else { + // No properties + if (parser.parseGreater()) + return {}; + + auto pattern = StructuredSparsityPatternAttr::get( + parser.getContext(), kind, lowerBandwidth, upperBandwidth); + return StructuredSparsityAttr::get(parser.getContext(), pattern, + valueProperties); + } + } + + // Parse properties + if (!parser.parseOptionalRBrace().succeeded()) { + do { + llvm::StringRef propStr; + if (parser.parseKeyword(&propStr)) + return {}; + + auto propOpt = symbolizeStructuredValueProperty(propStr); + if (!propOpt) { + parser.emitError(parser.getCurrentLocation(), + "invalid value property: ") + << propStr; + return {}; + } + valueProperties.push_back(*propOpt); + } while 
(succeeded(parser.parseOptionalComma())); + + if (parser.parseRBrace()) + return {}; + } + + if (parser.parseGreater()) + return {}; + + auto pattern = StructuredSparsityPatternAttr::get( + parser.getContext(), kind, lowerBandwidth, upperBandwidth); + return StructuredSparsityAttr::get(parser.getContext(), pattern, + valueProperties); +} + void EnzymeXLADialect::initialize() { addInterfaces(); addOperations< diff --git a/src/enzyme_ad/jax/Dialect/EnzymeXLAAttrs.td b/src/enzyme_ad/jax/Dialect/EnzymeXLAAttrs.td index 11bf8ca25..b130c84d6 100644 --- a/src/enzyme_ad/jax/Dialect/EnzymeXLAAttrs.td +++ b/src/enzyme_ad/jax/Dialect/EnzymeXLAAttrs.td @@ -133,4 +133,77 @@ def EnzymeXLA_GuaranteedAnalysisResult : I32EnumAttr<"GuaranteedAnalysisResult", def EnzymeXLA_GuaranteedAnalysisResultAttr : EnumAttr; +def EnzymeXLA_StructuredSparsityKind : I32EnumAttr<"StructuredSparsityKind", + "Kind of sparsity pattern", + [ + I32EnumAttrCase<"Unknown", 0>, + I32EnumAttrCase<"Dense", 1>, + I32EnumAttrCase<"Band", 2>, + I32EnumAttrCase<"UpperTriangular", 3>, + I32EnumAttrCase<"UpperBidiagonal", 4>, + I32EnumAttrCase<"LowerTriangular", 5>, + I32EnumAttrCase<"LowerBidiagonal", 6>, + I32EnumAttrCase<"Tridiagonal", 7>, + I32EnumAttrCase<"Diagonal", 8>, + I32EnumAttrCase<"Empty", 9> + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::enzymexla"; +} + +def EnzymeXLA_StructuredSparsityKindAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def EnzymeXLA_StructuredSparsityPattern : AttrDef { + let summary = "Structured Sparsity Pattern"; + + let parameters = (ins + "::mlir::enzymexla::StructuredSparsityKind":$kind, + "int64_t":$lowerBandwidth, + "int64_t":$upperBandwidth + ); + + let assemblyFormat = [{ + `<` $kind ` ` `[` $lowerBandwidth `,` ` ` $upperBandwidth `]` `>` + }]; + + let mnemonic = "structured_sparsity_pattern"; + let cppNamespace = "::mlir::enzymexla"; +} + +def EnzymeXLA_StructuredValueProperty : I32EnumAttr<"StructuredValueProperty", + "Value 
properties", + [ + I32EnumAttrCase<"UnitDiagonal", 0>, + I32EnumAttrCase<"Symmetric", 1>, + I32EnumAttrCase<"Hermitian", 2>, + I32EnumAttrCase<"BroadcastedScalar", 3> + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::enzymexla"; +} + +def EnzymeXLA_StructuredValuePropertyAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def EnzymeXLA_StructuredSparsity : AttrDef { + let summary = "Structured Sparsity"; + let cppNamespace = "::mlir::enzymexla"; + + let parameters = (ins + EnzymeXLA_StructuredSparsityPattern:$pattern, + ArrayRefParameter<"StructuredValueProperty">:$valueProperties + ); + + let mnemonic = "structured_sparsity"; + + let hasCustomAssemblyFormat = 1; +} + #endif // ENZYMEXLA_ATTRS diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td index b872474b4..b29fa212f 100644 --- a/src/enzyme_ad/jax/Passes/Passes.td +++ b/src/enzyme_ad/jax/Passes/Passes.td @@ -1081,6 +1081,7 @@ def StructuredMatrixSimplifyPass : Pass<"structured-matrix-simplify", "ModuleOp" let summary = "Simplify structured matrix operations"; let dependentDialects = [ "stablehlo::StablehloDialect", + "enzymexla::EnzymeXLADialect", "func::FuncDialect", ]; } diff --git a/src/enzyme_ad/jax/Passes/StructuredMatrixSimplify.cpp b/src/enzyme_ad/jax/Passes/StructuredMatrixSimplify.cpp index 4b31eeaa5..3c5fe8b73 100644 --- a/src/enzyme_ad/jax/Passes/StructuredMatrixSimplify.cpp +++ b/src/enzyme_ad/jax/Passes/StructuredMatrixSimplify.cpp @@ -13,6 +13,7 @@ #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "src/enzyme_ad/jax/Dialect/Dialect.h" #include "stablehlo/dialect/StablehloOps.h" #define DEBUG_TYPE "structured-matrix-simplify" @@ -48,7 +49,70 @@ class StructuredMatrixSimplifyPass return signalPassFailure(); } - // TODO: annotate the IR with the properties for later usage (use an option) + auto mod = getOperation(); + + // TODO: make IR annotation optional via an option + mod->walk([&](Operation 
*op) { + SmallVector structuredSparsityAttrs; + bool anyKnown = false; + for (auto result : op->getResults()) { + auto *state = + solver.lookupState( + result); + if (!state) { + structuredSparsityAttrs.push_back( + enzymexla::StructuredSparsityAttr::get( + mod.getContext(), + enzymexla::StructuredSparsityPatternAttr::get( + mod.getContext(), + enzymexla::StructuredSparsityKind::Unknown, -1, -1), + SmallVector())); + continue; + } + + anyKnown = true; + + // TODO: get structured sparsity kind + auto structuredSparsityKind = + enzymexla::StructuredSparsityPatternAttr::get( + mod.getContext(), enzymexla::StructuredSparsityKind::Unknown, + state->getValue().getSparsityPattern().getLowerBandwidth(), + state->getValue().getSparsityPattern().getUpperBandwidth()); + + SmallVector + structuredValueProperties; + auto valueProperties = state->getValue().getProperties(); + if (valueProperties.hasUnitDiagonal()) { + structuredValueProperties.push_back( + enzymexla::StructuredValueProperty::UnitDiagonal); + } + if (valueProperties.isSymmetric()) { + structuredValueProperties.push_back( + enzymexla::StructuredValueProperty::Symmetric); + } + if (valueProperties.isHermitian()) { + structuredValueProperties.push_back( + enzymexla::StructuredValueProperty::Hermitian); + } + if (valueProperties.isBroadcastedScalar()) { + structuredValueProperties.push_back( + enzymexla::StructuredValueProperty::BroadcastedScalar); + } + + auto structuredSparsity = enzymexla::StructuredSparsityAttr::get( + mod.getContext(), structuredSparsityKind, + structuredValueProperties); + + structuredSparsityAttrs.push_back(structuredSparsity); + } + + if (anyKnown) { + op->setAttr("structured_sparsity", + ArrayAttr::get(mod.getContext(), structuredSparsityAttrs)); + } + + return WalkResult::advance(); + }); // TODO: do things here } diff --git a/test/lit_tests/structured_tensors/newton_schulz.mlir b/test/lit_tests/structured_tensors/newton_schulz.mlir new file mode 100644 index 000000000..5d804cc80 --- 
/dev/null +++ b/test/lit_tests/structured_tensors/newton_schulz.mlir @@ -0,0 +1,29 @@ +// RUN: enzymexlamlir-opt --structured-matrix-simplify %s | FileCheck %s + +module { + func.func @main(%arg0: tensor<5x4xf32>) -> tensor<5x4xf32> { + %c = stablehlo.constant dense<0> : tensor + %c_0 = stablehlo.constant dense<5> : tensor + %c_1 = stablehlo.constant dense<1> : tensor + %cst = stablehlo.constant dense<2.031500e+00> : tensor<4x4xf32> + %cst_2 = stablehlo.constant dense<-4.775000e+00> : tensor<4x4xf32> + %cst_3 = stablehlo.constant dense<3.444500e+00> : tensor<5x4xf32> + %0:2 = stablehlo.while(%iterArg = %c, %iterArg_4 = %arg0) : tensor, tensor<5x4xf32> + cond { + %1 = stablehlo.compare LT, %iterArg, %c_0 : (tensor, tensor) -> tensor + stablehlo.return %1 : tensor + } do { + %1 = stablehlo.add %iterArg, %c_1 : tensor + %2 = stablehlo.dot_general %iterArg_4, %iterArg_4, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<4x4xf32> + %3 = stablehlo.multiply %cst_2, %2 : tensor<4x4xf32> + %4 = stablehlo.multiply %cst, %2 : tensor<4x4xf32> + %5 = stablehlo.dot_general %4, %2, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + %6 = stablehlo.add %3, %5 : tensor<4x4xf32> + %7 = stablehlo.multiply %cst_3, %iterArg_4 : tensor<5x4xf32> + %8 = stablehlo.dot_general %iterArg_4, %6, contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] : (tensor<5x4xf32>, tensor<4x4xf32>) -> tensor<5x4xf32> + %9 = stablehlo.add %7, %8 : tensor<5x4xf32> + stablehlo.return %1, %9 : tensor, tensor<5x4xf32> + } + return %0#1 : tensor<5x4xf32> + } +} From 74dd78215794abaad24402e7248314205686d599 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 23 Nov 2025 17:55:42 -0600 Subject: [PATCH 08/11] feat: cleaner printing --- .../jax/Analysis/StructuredMatrixAnalysis.h | 1 + .../jax/Passes/StructuredMatrixSimplify.cpp | 46 +++++++++++++++++-- 2 files changed, 44 insertions(+), 
3 deletions(-) diff --git a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h index 4ffbe9fc9..fdc10cc84 100644 --- a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h +++ b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h @@ -67,6 +67,7 @@ class StructuredSparsityPattern { refineKind(); } + StructuredSparsityKind getKind() const { return kind; } int64_t getLowerBandwidth() const { return lowerBandwidth; } int64_t getUpperBandwidth() const { return upperBandwidth; } diff --git a/src/enzyme_ad/jax/Passes/StructuredMatrixSimplify.cpp b/src/enzyme_ad/jax/Passes/StructuredMatrixSimplify.cpp index 3c5fe8b73..ae46ded44 100644 --- a/src/enzyme_ad/jax/Passes/StructuredMatrixSimplify.cpp +++ b/src/enzyme_ad/jax/Passes/StructuredMatrixSimplify.cpp @@ -70,12 +70,48 @@ class StructuredMatrixSimplifyPass continue; } - anyKnown = true; + if (state->getValue().getSparsityPattern().getKind() != + mlir::structure_analysis::StructuredSparsityKind::Unknown) { + anyKnown = true; + } + + enzymexla::StructuredSparsityKind ssKind; + switch (state->getValue().getSparsityPattern().getKind()) { + case mlir::structure_analysis::StructuredSparsityKind::Unknown: + ssKind = enzymexla::StructuredSparsityKind::Unknown; + break; + case mlir::structure_analysis::StructuredSparsityKind::Dense: + ssKind = enzymexla::StructuredSparsityKind::Dense; + break; + case mlir::structure_analysis::StructuredSparsityKind::Band: + ssKind = enzymexla::StructuredSparsityKind::Band; + break; + case mlir::structure_analysis::StructuredSparsityKind::UpperTriangular: + ssKind = enzymexla::StructuredSparsityKind::UpperTriangular; + break; + case mlir::structure_analysis::StructuredSparsityKind::UpperBidiagonal: + ssKind = enzymexla::StructuredSparsityKind::UpperBidiagonal; + break; + case mlir::structure_analysis::StructuredSparsityKind::LowerTriangular: + ssKind = enzymexla::StructuredSparsityKind::LowerTriangular; + break; + case 
mlir::structure_analysis::StructuredSparsityKind::LowerBidiagonal: + ssKind = enzymexla::StructuredSparsityKind::LowerBidiagonal; + break; + case mlir::structure_analysis::StructuredSparsityKind::Tridiagonal: + ssKind = enzymexla::StructuredSparsityKind::Tridiagonal; + break; + case mlir::structure_analysis::StructuredSparsityKind::Diagonal: + ssKind = enzymexla::StructuredSparsityKind::Diagonal; + break; + case mlir::structure_analysis::StructuredSparsityKind::Empty: + ssKind = enzymexla::StructuredSparsityKind::Empty; + break; + } - // TODO: get structured sparsity kind auto structuredSparsityKind = enzymexla::StructuredSparsityPatternAttr::get( - mod.getContext(), enzymexla::StructuredSparsityKind::Unknown, + mod.getContext(), ssKind, state->getValue().getSparsityPattern().getLowerBandwidth(), state->getValue().getSparsityPattern().getUpperBandwidth()); @@ -83,18 +119,22 @@ class StructuredMatrixSimplifyPass structuredValueProperties; auto valueProperties = state->getValue().getProperties(); if (valueProperties.hasUnitDiagonal()) { + anyKnown = true; structuredValueProperties.push_back( enzymexla::StructuredValueProperty::UnitDiagonal); } if (valueProperties.isSymmetric()) { + anyKnown = true; structuredValueProperties.push_back( enzymexla::StructuredValueProperty::Symmetric); } if (valueProperties.isHermitian()) { + anyKnown = true; structuredValueProperties.push_back( enzymexla::StructuredValueProperty::Hermitian); } if (valueProperties.isBroadcastedScalar()) { + anyKnown = true; structuredValueProperties.push_back( enzymexla::StructuredValueProperty::BroadcastedScalar); } From 9580768d2aaaa4b23b29a3dde0b1d5f2edecf067 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 23 Nov 2025 20:08:14 -0600 Subject: [PATCH 09/11] feat: more propagation rules --- .../jax/Analysis/StructuredMatrixAnalysis.cpp | 77 +++++++++++++++++-- .../jax/Analysis/StructuredMatrixAnalysis.h | 13 +++- 2 files changed, 82 insertions(+), 8 deletions(-) diff --git 
a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp index 148ff55a6..e6e729e70 100644 --- a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp +++ b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp @@ -33,8 +33,6 @@ StructuredSparsityPattern::StructuredSparsityPattern(Value v) { return; } - llvm::errs() << "TODO: structured sparsity pattern not implemented for " << v - << "\n"; setUnknown(); return; } @@ -130,9 +128,8 @@ StructuredSparsityPattern::meet(const StructuredSparsityPattern &lhs, if (rhs.kind == StructuredSparsityKind::Unknown) return lhs; - // for all other cases, we take the min of the bandwidths and refine - auto lb = std::min(lhs.lowerBandwidth, rhs.lowerBandwidth); - auto ub = std::min(lhs.upperBandwidth, rhs.upperBandwidth); + auto lb = std::max(lhs.lowerBandwidth, rhs.lowerBandwidth); + auto ub = std::max(lhs.upperBandwidth, rhs.upperBandwidth); auto newPattern = StructuredSparsityPattern(lb, ub); newPattern.refineKind(); return newPattern; @@ -151,13 +148,21 @@ StructuredSparsityPattern::join(const StructuredSparsityPattern &lhs, rhs.kind == StructuredSparsityKind::Unknown) return StructuredSparsityPattern(StructuredSparsityKind::Unknown); - auto lb = std::max(lhs.lowerBandwidth, rhs.lowerBandwidth); - auto ub = std::max(lhs.upperBandwidth, rhs.upperBandwidth); + auto lb = std::min(lhs.lowerBandwidth, rhs.lowerBandwidth); + auto ub = std::min(lhs.upperBandwidth, rhs.upperBandwidth); auto newPattern = StructuredSparsityPattern(lb, ub); newPattern.refineKind(); return newPattern; } +StructuredSparsityPattern StructuredSparsityPattern::propagateTranspose( + const StructuredSparsityPattern &op) { + auto newPattern = StructuredSparsityPattern(op.upperBandwidth, + op.lowerBandwidth); + newPattern.refineKind(); + return newPattern; +} + void StructuredSparsityPattern::print(raw_ostream &os) const { switch (kind) { case StructuredSparsityKind::Unknown: @@ -413,6 +418,31 @@ void 
StructuredMatrixType::print(raw_ostream &os) const { os << ")"; } +StructuredMatrixType StructuredMatrixType::propagateTranspose( + const StructuredMatrixType &op) { + return StructuredMatrixType( + StructuredSparsityPattern::propagateTranspose(op.sparsityPattern), + op.valueProperties); +} + +StructuredMatrixType StructuredMatrixType::propagateAdd( + const StructuredMatrixType &lhs, const StructuredMatrixType &rhs) { + ValueProperties valProps; + // TODO: If one is unit diag and other is zeros, we can propagate the other + // to the unit diag + if (lhs.getProperties().isSymmetric() && rhs.getProperties().isSymmetric()) { + valProps.set(ValueProperty::Symmetric); + } + if (lhs.getProperties().isBroadcastedScalar() && + rhs.getProperties().isBroadcastedScalar()) { + valProps.set(ValueProperty::BroadcastedScalar); + } + + return StructuredMatrixType( + StructuredSparsityPattern::meet(lhs.sparsityPattern, rhs.sparsityPattern), + valProps); +} + //===----------------------------------------------------------------------===// // Lattice Element //===----------------------------------------------------------------------===// @@ -465,6 +495,34 @@ void StructuredMatrixAnalysis::setToEntryState( LogicalResult StructuredMatrixAnalysis::visitOperation( Operation *op, ArrayRef operands, ArrayRef results) { + SmallVector updatedProps(results.size(), false); + SmallVector propagatedProps(results.size()); + + // transpose + if (auto transposeOp = dyn_cast(op)) { + updatedProps[0] = true; + propagatedProps[0] = StructuredMatrixType::propagateTranspose( + operands[0]->getValue()); + } + + // elementwise + /// add + if (auto addOp = dyn_cast(op)) { + updatedProps[0] = true; + propagatedProps[0] = StructuredMatrixType::propagateAdd( + operands[0]->getValue(), operands[1]->getValue()); + } + + /// mul + + // finalize + for (size_t i = 0; i < results.size(); i++) { + if (updatedProps[i]) { + results[i]->setValue( + StructuredMatrixType::join(results[i]->getValue(), 
propagatedProps[i])); + } + } + llvm::errs() << "Visiting operation " << *op << "\n"; for (auto operand : operands) { @@ -472,6 +530,11 @@ LogicalResult StructuredMatrixAnalysis::visitOperation( operand->getValue().print(llvm::errs()); llvm::errs() << "\n"; } + for (auto result : results) { + llvm::errs() << " result: "; + result->getValue().print(llvm::errs()); + llvm::errs() << "\n"; + } llvm::errs() << "\n"; return success(); diff --git a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h index fdc10cc84..cdc0cfa1d 100644 --- a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h +++ b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h @@ -88,6 +88,10 @@ class StructuredSparsityPattern { return os; } + // propagation rules + static StructuredSparsityPattern propagateTranspose( + const StructuredSparsityPattern &op); + private: void initializeBandwidths(); void refineKind(); @@ -203,7 +207,14 @@ class StructuredMatrixType { return os; } - // TODO: propagation rules probably goes in here + // propagation rules + static StructuredMatrixType propagateTranspose(const StructuredMatrixType &op); + + static StructuredMatrixType propagateAdd(const StructuredMatrixType &lhs, + const StructuredMatrixType &rhs); + + static StructuredMatrixType propagateMultiply(const StructuredMatrixType &lhs, + const StructuredMatrixType &rhs); // TODO: implement queries that check both the sparsity pattern and value // properties and return specific matrix kinds From fea32b16bf414f1760bb86ba9ffd9190f88a9868 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 23 Nov 2025 23:24:52 -0600 Subject: [PATCH 10/11] feat: more propagation rules --- .../jax/Analysis/StructuredMatrixAnalysis.cpp | 176 ++++++++++++++---- .../jax/Analysis/StructuredMatrixAnalysis.h | 31 ++- .../jax/Passes/StructuredMatrixSimplify.cpp | 2 +- 3 files changed, 168 insertions(+), 41 deletions(-) diff --git 
a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp index e6e729e70..f3c43dfed 100644 --- a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp +++ b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp @@ -156,9 +156,9 @@ StructuredSparsityPattern::join(const StructuredSparsityPattern &lhs, } StructuredSparsityPattern StructuredSparsityPattern::propagateTranspose( - const StructuredSparsityPattern &op) { - auto newPattern = StructuredSparsityPattern(op.upperBandwidth, - op.lowerBandwidth); + Value val, const StructuredSparsityPattern &op) { + auto newPattern = + StructuredSparsityPattern(op.upperBandwidth, op.lowerBandwidth); newPattern.refineKind(); return newPattern; } @@ -253,11 +253,52 @@ ValueProperties::ValueProperties(Value v) { } } - // TODO: A x A^T will always be symmetric + if (auto dotGeneralOp = dyn_cast(defOp)) { + auto dotDimNumbers = dotGeneralOp.getDotDimensionNumbers(); + auto lhs = dotGeneralOp.getLhs(); + auto rhs = dotGeneralOp.getRhs(); + + if (dotDimNumbers.getLhsBatchingDimensions().size() == 0 && + dotDimNumbers.getRhsBatchingDimensions().size() == 0) { + // lhs == rhs => check for the dimension numbers + if (lhs == rhs) { + if (dotDimNumbers.getLhsContractingDimensions().size() == 1 && + dotDimNumbers.getRhsContractingDimensions().size() == 1 && + dotDimNumbers.getLhsContractingDimensions()[0] == + dotDimNumbers.getRhsContractingDimensions()[0]) { + set(ValueProperty::Symmetric); + } + } + + // check operands are transposed: `A x A^T` and `A^T x A` + if (auto lhsT = lhs.getDefiningOp()) { + if (isTrueTranspose(lhsT) && rhs == lhsT.getOperand()) { + if (dotDimNumbers.getLhsContractingDimensions().size() == 1 && + dotDimNumbers.getRhsContractingDimensions().size() == 1 && + dotDimNumbers.getLhsContractingDimensions()[0] == + 1 - dotDimNumbers.getRhsContractingDimensions()[0]) { + set(ValueProperty::Symmetric); + } + } + } + + if (auto rhsT = rhs.getDefiningOp()) { + if 
(isTrueTranspose(rhsT) && lhs == rhsT.getOperand()) { + if (dotDimNumbers.getLhsContractingDimensions().size() == 1 && + dotDimNumbers.getRhsContractingDimensions().size() == 1 && + dotDimNumbers.getLhsContractingDimensions()[0] == + 1 - dotDimNumbers.getRhsContractingDimensions()[0]) { + set(ValueProperty::Symmetric); + } + } + } + } + } if (auto bcastOp = dyn_cast(defOp)) { auto operand = bcastOp.getOperand(); - if (cast(operand.getType()).getRank() == 0) { // bcast(scalar) + if (cast(operand.getType()).getRank() == + 0) { // bcast(scalar) if (matchPattern(operand, m_One())) // bcast(1) set(ValueProperty::UnitDiagonal); set(ValueProperty::BroadcastedScalar); @@ -418,31 +459,86 @@ void StructuredMatrixType::print(raw_ostream &os) const { os << ")"; } -StructuredMatrixType StructuredMatrixType::propagateTranspose( - const StructuredMatrixType &op) { +StructuredMatrixType +StructuredMatrixType::propagateTranspose(Value val, + const StructuredMatrixType &op) { return StructuredMatrixType( - StructuredSparsityPattern::propagateTranspose(op.sparsityPattern), + StructuredSparsityPattern::propagateTranspose(val, op.sparsityPattern), op.valueProperties); } -StructuredMatrixType StructuredMatrixType::propagateAdd( - const StructuredMatrixType &lhs, const StructuredMatrixType &rhs) { +StructuredMatrixType +StructuredMatrixType::propagateAdd(Value lhs, Value rhs, + const StructuredMatrixType &lhsType, + const StructuredMatrixType &rhsType) { ValueProperties valProps; - // TODO: If one is unit diag and other is zeros, we can propagate the other + + // If one is unit diag and other is zeros, we can propagate the other // to the unit diag - if (lhs.getProperties().isSymmetric() && rhs.getProperties().isSymmetric()) { + SplatElementsAttr lhsSplat, rhsSplat; + if (lhsType.getProperties().hasUnitDiagonal() && + matchPattern(rhs, m_Constant(&rhsSplat))) { + if (utils::isZero(rhsSplat.getSplatValue())) { + valProps.set(ValueProperty::UnitDiagonal); + } + } + if 
(rhsType.getProperties().hasUnitDiagonal() && + matchPattern(lhs, m_Constant(&lhsSplat))) { + if (utils::isZero(lhsSplat.getSplatValue())) { + valProps.set(ValueProperty::UnitDiagonal); + } + } + + if (lhsType.getProperties().isSymmetric() && + rhsType.getProperties().isSymmetric()) { valProps.set(ValueProperty::Symmetric); } - if (lhs.getProperties().isBroadcastedScalar() && - rhs.getProperties().isBroadcastedScalar()) { + if (lhsType.getProperties().isBroadcastedScalar() && + rhsType.getProperties().isBroadcastedScalar()) { valProps.set(ValueProperty::BroadcastedScalar); } return StructuredMatrixType( - StructuredSparsityPattern::meet(lhs.sparsityPattern, rhs.sparsityPattern), + StructuredSparsityPattern::meet(lhsType.sparsityPattern, + rhsType.sparsityPattern), valProps); } +StructuredMatrixType +StructuredMatrixType::propagateMultiply(Value lhs, Value rhs, + const StructuredMatrixType &lhsType, + const StructuredMatrixType &rhsType) { + return StructuredMatrixType::meet(lhsType, rhsType); +} + +// TODO: we ideally want to special case elementwise ops that preserve certain +// properties +StructuredMatrixType StructuredMatrixType::propagateElementwise( + ArrayRef operands, + SmallVectorImpl &operandsType) { + // TODO: propagate structure + + ValueProperties valueProperties; + // TODO: propagate hermitian + bool allSymmetric = true, allScalar = true; + for (auto opType : operandsType) { + if (!opType.getProperties().isSymmetric()) { + allSymmetric = false; + } + if (!opType.getProperties().isBroadcastedScalar()) { + allScalar = false; + } + } + if (allSymmetric) { + valueProperties.set(ValueProperty::Symmetric); + } + if (allScalar) { + valueProperties.set(ValueProperty::BroadcastedScalar); + } + + return StructuredMatrixType(StructuredSparsityPattern(), valueProperties); +} + //===----------------------------------------------------------------------===// // Lattice Element //===----------------------------------------------------------------------===// @@ 
-498,11 +594,16 @@ LogicalResult StructuredMatrixAnalysis::visitOperation( SmallVector updatedProps(results.size(), false); SmallVector propagatedProps(results.size()); + SmallVector operandValues(operands.size()); + for (size_t i = 0; i < operands.size(); i++) { + operandValues[i] = operands[i]->getValue(); + } + // transpose if (auto transposeOp = dyn_cast(op)) { updatedProps[0] = true; propagatedProps[0] = StructuredMatrixType::propagateTranspose( - operands[0]->getValue()); + transposeOp.getOperand(), operandValues[0]); } // elementwise @@ -510,33 +611,42 @@ LogicalResult StructuredMatrixAnalysis::visitOperation( if (auto addOp = dyn_cast(op)) { updatedProps[0] = true; propagatedProps[0] = StructuredMatrixType::propagateAdd( - operands[0]->getValue(), operands[1]->getValue()); + addOp.getLhs(), addOp.getRhs(), operandValues[0], operandValues[1]); } /// mul + if (auto mulOp = dyn_cast(op)) { + updatedProps[0] = true; + propagatedProps[0] = StructuredMatrixType::propagateMultiply( + mulOp.getLhs(), mulOp.getRhs(), operandValues[0], operandValues[1]); + } + + /// fallback for elementwise ops not handled by a dedicated rule above + if (stablehlo::hasTraitElementwise(op) && !updatedProps[0]) { + updatedProps[0] = true; + propagatedProps[0] = StructuredMatrixType::propagateElementwise( + llvm::to_vector<3>(op->getOperands()), operandValues); + } + + // pass through ops + if (isa(op)) { + updatedProps[0] = true; + propagatedProps[0] = operandValues[0]; + } // finalize for (size_t i = 0; i < results.size(); i++) { if (updatedProps[i]) { - results[i]->setValue( - StructuredMatrixType::join(results[i]->getValue(), propagatedProps[i])); + auto resultOrig = results[i]->getValue(); + auto resultNew = + StructuredMatrixType::join(resultOrig, propagatedProps[i]); + results[i]->setValue(resultNew); + propagateIfChanged(results[i], resultNew == resultOrig + ? 
ChangeResult::NoChange + : ChangeResult::Change); } } - - llvm::errs() << "Visiting operation " << *op << "\n"; - for (auto operand : operands) { - llvm::errs() << " operand: "; - operand->getValue().print(llvm::errs()); - llvm::errs() << "\n"; - } - for (auto result : results) { - llvm::errs() << " result: "; - result->getValue().print(llvm::errs()); - llvm::errs() << "\n"; - } - llvm::errs() << "\n"; - return success(); } diff --git a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h index cdc0cfa1d..5fd379513 100644 --- a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h +++ b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h @@ -12,6 +12,16 @@ namespace structure_analysis { namespace utils { +static bool isZero(APInt v) { return v.isZero(); } +static bool isZero(APFloat v) { return v.isZero(); } +static bool isZero(Attribute v) { + if (auto intAttr = dyn_cast(v)) + return isZero(intAttr.getValue()); + if (auto floatAttr = dyn_cast(v)) + return isZero(floatAttr.getValue()); + return false; +} + static bool isOne(APInt v) { return v.isOne(); } static bool isOne(APFloat v) { return v.isExactlyValue(1.0); } static bool isOne(Attribute v) { @@ -89,8 +99,8 @@ class StructuredSparsityPattern { } // propagation rules - static StructuredSparsityPattern propagateTranspose( - const StructuredSparsityPattern &op); + static StructuredSparsityPattern + propagateTranspose(Value val, const StructuredSparsityPattern &op); private: void initializeBandwidths(); @@ -208,13 +218,20 @@ class StructuredMatrixType { } // propagation rules - static StructuredMatrixType propagateTranspose(const StructuredMatrixType &op); + static StructuredMatrixType + propagateTranspose(Value val, const StructuredMatrixType &op); + + static StructuredMatrixType propagateAdd(Value lhs, Value rhs, + const StructuredMatrixType &lhsType, + const StructuredMatrixType &rhsType); - static StructuredMatrixType propagateAdd(const 
StructuredMatrixType &lhs, - const StructuredMatrixType &rhs); + static StructuredMatrixType + propagateMultiply(Value lhs, Value rhs, const StructuredMatrixType &lhsType, + const StructuredMatrixType &rhsType); - static StructuredMatrixType propagateMultiply(const StructuredMatrixType &lhs, - const StructuredMatrixType &rhs); + static StructuredMatrixType + propagateElementwise(ArrayRef operands, + SmallVectorImpl &operandsType); // TODO: implement queries that check both the sparsity pattern and value // properties and return specific matrix kinds diff --git a/src/enzyme_ad/jax/Passes/StructuredMatrixSimplify.cpp b/src/enzyme_ad/jax/Passes/StructuredMatrixSimplify.cpp index ae46ded44..ac6ed2f22 100644 --- a/src/enzyme_ad/jax/Passes/StructuredMatrixSimplify.cpp +++ b/src/enzyme_ad/jax/Passes/StructuredMatrixSimplify.cpp @@ -147,7 +147,7 @@ class StructuredMatrixSimplifyPass } if (anyKnown) { - op->setAttr("structured_sparsity", + op->setAttr("enzymexla.structured_sparsity", ArrayAttr::get(mod.getContext(), structuredSparsityAttrs)); } From 98e44f885bf38eb1cca6205de25a9b7ce11050e1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 27 Nov 2025 15:27:59 -0600 Subject: [PATCH 11/11] feat: structure property inference --- .../jax/Analysis/StructuredMatrixAnalysis.cpp | 130 ++++++++++++++++-- .../jax/Analysis/StructuredMatrixAnalysis.h | 4 +- src/enzyme_ad/jax/Dialect/Dialect.cpp | 5 +- test/lit_tests/structured_tensors/banded.mlir | 25 ++++ 4 files changed, 145 insertions(+), 19 deletions(-) create mode 100644 test/lit_tests/structured_tensors/banded.mlir diff --git a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp index f3c43dfed..259b7dbbc 100644 --- a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp +++ b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp @@ -1,4 +1,5 @@ #include "src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h" +#include 
"src/enzyme_ad/jax/Passes/StructuredTensors.h" #include "src/enzyme_ad/jax/Utils.h" #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" @@ -33,6 +34,102 @@ StructuredSparsityPattern::StructuredSparsityPattern(Value v) { return; } + auto vTy = cast(v.getType()); + if (!vTy.hasStaticShape() || vTy.getRank() != 2) { + setUnknown(); + return; + } + int64_t nrows = vTy.getDimSize(0); + int64_t ncols = vTy.getDimSize(1); + + auto defOp = v.getDefiningOp(); + if (!defOp) { + setUnknown(); + return; + } + + if (auto compareOp = dyn_cast(defOp)) { + auto lhs = compareOp.getLhs(); + auto rhs = compareOp.getRhs(); + + auto lhsIotaLikeTensor = enzyme::detectIotaLikeTensor(lhs); + auto rhsIotaLikeTensor = enzyme::detectIotaLikeTensor(rhs); + + if (lhsIotaLikeTensor && rhsIotaLikeTensor) { + auto lhsIotaLikeTensorValue = *lhsIotaLikeTensor; + auto rhsIotaLikeTensorValue = *rhsIotaLikeTensor; + + bool lhsIsRow = lhsIotaLikeTensorValue.dimension == 0; + bool rhsIsRow = rhsIotaLikeTensorValue.dimension == 0; + bool lhsIsCol = lhsIotaLikeTensorValue.dimension == 1; + bool rhsIsCol = rhsIotaLikeTensorValue.dimension == 1; + + auto direction = compareOp.getComparisonDirection(); + int64_t offset; + + if (lhsIsRow && rhsIsCol) { + offset = lhsIotaLikeTensorValue.start - rhsIotaLikeTensorValue.start; + } else if (lhsIsCol && rhsIsRow) { + offset = rhsIotaLikeTensorValue.start - lhsIotaLikeTensorValue.start; + direction = reversedComparisonDirection(direction); + } else { + setUnknown(); + return; + } + + // TODO: verify calculation here; LE yields a narrower band than LT, which looks swapped + switch (direction) { + case stablehlo::ComparisonDirection::LT: + upperBandwidth = ncols - 1 + offset; + lowerBandwidth = -offset; + break; + case stablehlo::ComparisonDirection::LE: + upperBandwidth = ncols - 1 + offset; + lowerBandwidth = -offset - 1; + break; + case stablehlo::ComparisonDirection::GT: + upperBandwidth = -offset - 1; + lowerBandwidth = nrows - 1 + offset; + break; + case stablehlo::ComparisonDirection::GE: + 
upperBandwidth = -offset; + lowerBandwidth = nrows - 1 + offset; + break; + case stablehlo::ComparisonDirection::EQ: + // row == col => diagonal with offset + upperBandwidth = std::max(0L, -offset); + lowerBandwidth = std::max(0L, offset); + break; + case stablehlo::ComparisonDirection::NE: + setUnknown(); + return; + } + + lowerBandwidth = std::max(0L, std::min(lowerBandwidth, nrows - 1)); + upperBandwidth = std::max(0L, std::min(upperBandwidth, ncols - 1)); + kind = StructuredSparsityKind::Band; + refineKind(); + return; + } + } + + DenseElementsAttr denseAttr; + if (matchPattern(defOp, m_Constant(&denseAttr))) { + if (denseAttr.isSplat()) { + auto val = denseAttr.getSplatValue(); + if (utils::isZero(val)) { + kind = StructuredSparsityKind::Empty; + return; + } + } + + // TODO: get better sparsity pattern from denseAttr, for now we just + // assume it is dense + kind = StructuredSparsityKind::Dense; + initializeBandwidths(); + return; + } + setUnknown(); return; } @@ -44,6 +141,7 @@ void StructuredSparsityPattern::initializeBandwidths() { case StructuredSparsityKind::Dense: lowerBandwidth = std::numeric_limits::max(); upperBandwidth = std::numeric_limits::max(); + break; case StructuredSparsityKind::Band: llvm_unreachable("constructing band with no bandwidths"); case StructuredSparsityKind::UpperTriangular: @@ -71,6 +169,8 @@ void StructuredSparsityPattern::initializeBandwidths() { upperBandwidth = 0; break; case StructuredSparsityKind::Empty: + lowerBandwidth = -1; + upperBandwidth = -1; break; } } @@ -115,19 +215,20 @@ void StructuredSparsityPattern::refineKind() { } } -// most specific pattern +// intersection of the properties StructuredSparsityPattern StructuredSparsityPattern::meet(const StructuredSparsityPattern &lhs, const StructuredSparsityPattern &rhs) { - if (lhs.kind == StructuredSparsityKind::Empty || - rhs.kind == StructuredSparsityKind::Empty) - return StructuredSparsityPattern(StructuredSparsityKind::Empty); - if (lhs.kind == 
StructuredSparsityKind::Unknown) return rhs; if (rhs.kind == StructuredSparsityKind::Unknown) return lhs; + if (lhs.kind == StructuredSparsityKind::Empty) + return rhs; + if (rhs.kind == StructuredSparsityKind::Empty) + return lhs; + auto lb = std::max(lhs.lowerBandwidth, rhs.lowerBandwidth); auto ub = std::max(lhs.upperBandwidth, rhs.upperBandwidth); auto newPattern = StructuredSparsityPattern(lb, ub); @@ -135,18 +236,18 @@ StructuredSparsityPattern::meet(const StructuredSparsityPattern &lhs, return newPattern; } -// least specific structure containing both +// union of the properties StructuredSparsityPattern StructuredSparsityPattern::join(const StructuredSparsityPattern &lhs, const StructuredSparsityPattern &rhs) { - if (lhs.kind == StructuredSparsityKind::Empty) + if (lhs.kind == StructuredSparsityKind::Unknown) return rhs; - if (rhs.kind == StructuredSparsityKind::Empty) + if (rhs.kind == StructuredSparsityKind::Unknown) return lhs; - if (lhs.kind == StructuredSparsityKind::Unknown || - rhs.kind == StructuredSparsityKind::Unknown) - return StructuredSparsityPattern(StructuredSparsityKind::Unknown); + if (lhs.kind == StructuredSparsityKind::Empty || + rhs.kind == StructuredSparsityKind::Empty) + return StructuredSparsityPattern(StructuredSparsityKind::Empty); auto lb = std::min(lhs.lowerBandwidth, rhs.lowerBandwidth); auto ub = std::min(lhs.upperBandwidth, rhs.upperBandwidth); @@ -157,6 +258,10 @@ StructuredSparsityPattern::join(const StructuredSparsityPattern &lhs, StructuredSparsityPattern StructuredSparsityPattern::propagateTranspose( Value val, const StructuredSparsityPattern &op) { + if (op.kind == StructuredSparsityKind::Empty || + op.kind == StructuredSparsityKind::Unknown) + return StructuredSparsityPattern(op.kind); + auto newPattern = StructuredSparsityPattern(op.upperBandwidth, op.lowerBandwidth); newPattern.refineKind(); @@ -192,9 +297,6 @@ void StructuredSparsityPattern::print(raw_ostream &os) const { case StructuredSparsityKind::Diagonal: os << 
"Diagonal"; break; - case StructuredSparsityKind::Empty: - os << "Empty"; - break; } } diff --git a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h index 5fd379513..8ccd7360d 100644 --- a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h +++ b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h @@ -44,7 +44,6 @@ static bool areEqual(APFloat a, APFloat b) { //===----------------------------------------------------------------------===// enum class StructuredSparsityKind { - Unknown, Dense, Band, UpperTriangular, @@ -53,7 +52,8 @@ enum class StructuredSparsityKind { LowerBidiagonal, Tridiagonal, Diagonal, - Empty, // doesn't really mean anything, but we need it for bottom element + Empty, // denotes that all elements are structural zeros + Unknown, }; // TODO: currently only legal negative value is -1, which means "unknown" diff --git a/src/enzyme_ad/jax/Dialect/Dialect.cpp b/src/enzyme_ad/jax/Dialect/Dialect.cpp index 3c4ed5537..7a1569780 100644 --- a/src/enzyme_ad/jax/Dialect/Dialect.cpp +++ b/src/enzyme_ad/jax/Dialect/Dialect.cpp @@ -103,9 +103,8 @@ void StructuredSparsityAttr::print(::mlir::AsmPrinter &printer) const { auto pattern = getPattern(); auto kind = pattern.getKind(); - // Skip printing kind for Unknown and Empty - bool printKind = (kind != StructuredSparsityKind::Unknown && - kind != StructuredSparsityKind::Empty); + // Skip printing kind for Unknown + bool printKind = kind != StructuredSparsityKind::Unknown; if (printKind) { printer << stringifyStructuredSparsityKind(kind); diff --git a/test/lit_tests/structured_tensors/banded.mlir b/test/lit_tests/structured_tensors/banded.mlir new file mode 100644 index 000000000..4f4644b42 --- /dev/null +++ b/test/lit_tests/structured_tensors/banded.mlir @@ -0,0 +1,25 @@ +// RUN: enzymexlamlir-opt --structured-matrix-simplify %s | FileCheck %s + +// func.func @main1(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> { +// %c = stablehlo.constant 
dense<[[true, true, true, true, true, true, true, true, true, true], [true, true, true, true, true, true, true, true, true, true], [true, true, true, true, true, true, true, true, true, true], [false, true, true, true, true, true, true, true, true, true], [false, false, true, true, true, true, true, true, true, true], [false, false, false, true, true, true, true, true, true, true], [false, false, false, false, true, true, true, true, true, true], [false, false, false, false, false, true, true, true, true, true], [false, false, false, false, false, false, true, true, true, true], [false, false, false, false, false, false, false, true, true, true]]> : tensor<10x10xi1> +// %cst = stablehlo.constant dense<0.000000e+00> : tensor<10x10xf32> +// %c_0 = stablehlo.constant dense<[[true, true, true, true, false, false, false, false, false, false], [true, true, true, true, true, false, false, false, false, false], [true, true, true, true, true, true, false, false, false, false], [true, true, true, true, true, true, true, false, false, false], [true, true, true, true, true, true, true, true, false, false], [true, true, true, true, true, true, true, true, true, false], [true, true, true, true, true, true, true, true, true, true], [true, true, true, true, true, true, true, true, true, true], [true, true, true, true, true, true, true, true, true, true], [true, true, true, true, true, true, true, true, true, true]]> : tensor<10x10xi1> +// %0 = stablehlo.select %c_0, %arg0, %cst : tensor<10x10xi1>, tensor<10x10xf32> +// %1 = stablehlo.select %c, %0, %cst : tensor<10x10xi1>, tensor<10x10xf32> +// return %1 : tensor<10x10xf32> +// } + +func.func @main2(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> { + %c = stablehlo.constant dense<2> : tensor<10x10xi64> + %cst = stablehlo.constant dense<0.000000e+00> : tensor<10x10xf32> + %c_0 = stablehlo.constant dense<-3> : tensor<10x10xi64> + %0 = stablehlo.iota dim = 1 : tensor<10x10xi64> + %1 = 
stablehlo.iota dim = 0 : tensor<10x10xi64> + %2 = stablehlo.subtract %1, %c_0 : tensor<10x10xi64> + %3 = stablehlo.compare LE, %0, %2 : (tensor<10x10xi64>, tensor<10x10xi64>) -> tensor<10x10xi1> + %4 = stablehlo.subtract %1, %c : tensor<10x10xi64> + %5 = stablehlo.compare GE, %0, %4 : (tensor<10x10xi64>, tensor<10x10xi64>) -> tensor<10x10xi1> + %6 = stablehlo.select %3, %arg0, %cst : tensor<10x10xi1>, tensor<10x10xf32> + %7 = stablehlo.select %5, %6, %cst : tensor<10x10xi1>, tensor<10x10xf32> + return %7 : tensor<10x10xf32> +}