diff --git a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp
new file mode 100644
index 000000000..259b7dbbc
--- /dev/null
+++ b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp
@@ -0,0 +1,756 @@
+#include "src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h"
+#include "src/enzyme_ad/jax/Passes/StructuredTensors.h"
+#include "src/enzyme_ad/jax/Utils.h"
+
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include "stablehlo/dialect/StablehloOps.h"
+
+using namespace mlir;
+using namespace mlir::dataflow;
+
+namespace mlir {
+namespace structure_analysis {
+
+//===----------------------------------------------------------------------===//
+// Structured Sparsity Pattern Implementation
+//===----------------------------------------------------------------------===//
+
+StructuredSparsityPattern::StructuredSparsityPattern(Value v) {
+  if (auto blockArg = dyn_cast<BlockArgument>(v)) {
+    // TODO: If block arg is annotated with a pattern we should parse that
+    setUnknown(); // be pessimistic by default
+    return;
+  }
+
+  auto vTy = cast<RankedTensorType>(v.getType());
+  if (!vTy.hasStaticShape() || vTy.getRank() != 2) {
+    setUnknown();
+    return;
+  }
+  int64_t nrows = vTy.getDimSize(0);
+  int64_t ncols = vTy.getDimSize(1);
+
+  auto defOp = v.getDefiningOp();
+  if (!defOp) {
+    setUnknown();
+    return;
+  }
+
+  if (auto compareOp = dyn_cast<stablehlo::CompareOp>(defOp)) {
+    auto lhs = compareOp.getLhs();
+    auto rhs = compareOp.getRhs();
+
+    auto lhsIotaLikeTensor = enzyme::detectIotaLikeTensor(lhs);
+    auto rhsIotaLikeTensor = enzyme::detectIotaLikeTensor(rhs);
+
+    if (lhsIotaLikeTensor && rhsIotaLikeTensor) {
+      auto lhsIotaLikeTensorValue = *lhsIotaLikeTensor;
+      auto rhsIotaLikeTensorValue = *rhsIotaLikeTensor;
+
+      bool lhsIsRow = lhsIotaLikeTensorValue.dimension == 0;
+      bool rhsIsRow = rhsIotaLikeTensorValue.dimension == 0;
+      bool lhsIsCol = lhsIotaLikeTensorValue.dimension == 1;
+      bool rhsIsCol = rhsIotaLikeTensorValue.dimension == 1;
+
+      auto direction = compareOp.getComparisonDirection();
+      int64_t offset;
+
+      if (lhsIsRow && rhsIsCol) {
+        offset = lhsIotaLikeTensorValue.start - rhsIotaLikeTensorValue.start;
+      } else if (lhsIsCol && rhsIsRow) {
+        offset = rhsIotaLikeTensorValue.start - lhsIotaLikeTensorValue.start;
+        direction = reversedComparisonDirection(direction);
+      } else {
+        setUnknown();
+        return;
+      }
+
+      // TODO: verify calculation here
+      switch (direction) {
+      case stablehlo::ComparisonDirection::LT:
+        upperBandwidth = ncols - 1 + offset;
+        lowerBandwidth = -offset;
+        break;
+      case stablehlo::ComparisonDirection::LE:
+        upperBandwidth = ncols - 1 + offset;
+        lowerBandwidth = -offset - 1;
+        break;
+      case stablehlo::ComparisonDirection::GT:
+        upperBandwidth = -offset - 1;
+        lowerBandwidth = nrows - 1 + offset;
+        break;
+      case stablehlo::ComparisonDirection::GE:
+        upperBandwidth = -offset;
+        lowerBandwidth = nrows - 1 + offset;
+        break;
+      case stablehlo::ComparisonDirection::EQ:
+        // row == col => diagonal with offset
+        upperBandwidth = std::max(0L, -offset);
+        lowerBandwidth = std::max(0L, offset);
+        break;
+      case stablehlo::ComparisonDirection::NE:
+        setUnknown();
+        return;
+      }
+
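+      // Clamp to what is representable for an nrows x ncols matrix: the
+      // formulas above can produce bandwidths outside [0, dim - 1].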
+      lowerBandwidth = std::max(0L, std::min(lowerBandwidth, nrows - 1));
+      upperBandwidth = std::max(0L, std::min(upperBandwidth, ncols - 1));
+      kind = StructuredSparsityKind::Band;
+      refineKind();
+      return;
+    }
+  }
+
+  DenseElementsAttr denseAttr;
+  if (matchPattern(defOp, m_Constant(&denseAttr))) {
+    if (denseAttr.isSplat()) {
+      auto val = denseAttr.getSplatValue<Attribute>();
+      if (utils::isZero(val)) {
+        kind = StructuredSparsityKind::Empty;
+        return;
+      }
+    }
+
+    // TODO: get better sparsity pattern from denseAttr, for now we just
+    // assume it is dense
+    kind = StructuredSparsityKind::Dense;
+    initializeBandwidths();
+    return;
+  }
+
+  setUnknown();
+  return;
+}
+
+void StructuredSparsityPattern::initializeBandwidths() {
+  switch (kind) {
+  case StructuredSparsityKind::Unknown:
+    break; // leave as is
+  case StructuredSparsityKind::Dense:
+    lowerBandwidth = std::numeric_limits<int64_t>::max();
+    upperBandwidth = std::numeric_limits<int64_t>::max();
+    break;
+  case StructuredSparsityKind::Band:
+    llvm_unreachable("constructing band with no bandwidths");
+  case StructuredSparsityKind::UpperTriangular:
+    lowerBandwidth = 0;
+    upperBandwidth = std::numeric_limits<int64_t>::max();
+    break;
+  case StructuredSparsityKind::UpperBidiagonal:
+    lowerBandwidth = 0;
+    upperBandwidth = 1;
+    break;
+  case StructuredSparsityKind::LowerTriangular:
+    lowerBandwidth = std::numeric_limits<int64_t>::max();
+    upperBandwidth = 0;
+    break;
+  case StructuredSparsityKind::LowerBidiagonal:
+    lowerBandwidth = 1;
+    upperBandwidth = 0;
+    break;
+  case StructuredSparsityKind::Tridiagonal:
+    lowerBandwidth = 1;
+    upperBandwidth = 1;
+    break;
+  case StructuredSparsityKind::Diagonal:
+    lowerBandwidth = 0;
+    upperBandwidth = 0;
+    break;
+  case StructuredSparsityKind::Empty:
+    lowerBandwidth = -1;
+    upperBandwidth = -1;
+    break;
+  }
+}
+
+void StructuredSparsityPattern::refineKind() {
+  if (lowerBandwidth == 0) {
+    if (upperBandwidth == 0) {
+      kind = StructuredSparsityKind::Diagonal;
+      return;
+    }
+    if (upperBandwidth == 1) {
+      kind = StructuredSparsityKind::UpperBidiagonal;
+      return;
+    }
+    if (upperBandwidth == std::numeric_limits<int64_t>::max()) {
+      kind = StructuredSparsityKind::UpperTriangular;
+      return;
+    }
+  }
+
+  // lowerBandwidth != 0
+  if (upperBandwidth == 0) {
+    if (lowerBandwidth == 1) {
+      kind = StructuredSparsityKind::LowerBidiagonal;
+      return;
+    }
+    if (lowerBandwidth == std::numeric_limits<int64_t>::max()) {
+      kind = StructuredSparsityKind::LowerTriangular;
+      return;
+    }
+  }
+
+  if (lowerBandwidth == 1 && upperBandwidth == 1) {
+    kind = StructuredSparsityKind::Tridiagonal;
+    return;
+  }
+
+  if (lowerBandwidth == std::numeric_limits<int64_t>::max() &&
+      upperBandwidth == std::numeric_limits<int64_t>::max()) {
+    kind = StructuredSparsityKind::Dense;
+    return;
+  }
+}
+
+// meet: conservative combination -- the result is the wider band that covers
+// both operands' nonzero patterns
+StructuredSparsityPattern
+StructuredSparsityPattern::meet(const StructuredSparsityPattern &lhs,
+                                const StructuredSparsityPattern &rhs) {
+  if (lhs.kind == StructuredSparsityKind::Unknown)
+    return rhs;
+  if (rhs.kind == StructuredSparsityKind::Unknown)
+    return lhs;
+
+  if (lhs.kind == StructuredSparsityKind::Empty)
+    return rhs;
+  if (rhs.kind == StructuredSparsityKind::Empty)
+    return lhs;
+
+  auto lb = std::max(lhs.lowerBandwidth, rhs.lowerBandwidth);
+  auto ub = std::max(lhs.upperBandwidth, rhs.upperBandwidth);
+  auto newPattern = StructuredSparsityPattern(lb, ub);
+  newPattern.refineKind();
+  return newPattern;
+}
+
+// join: refining combination -- the result is the narrower band common to
+// both operands
+StructuredSparsityPattern
+StructuredSparsityPattern::join(const StructuredSparsityPattern &lhs,
+                                const StructuredSparsityPattern &rhs) {
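+  // Unknown acts as the identity here (return the other operand unchanged),
+  // while Empty absorbs everything.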
+  if (lhs.kind == StructuredSparsityKind::Unknown)
+    return rhs;
+  if (rhs.kind == StructuredSparsityKind::Unknown)
+    return lhs;
+
+  if (lhs.kind == StructuredSparsityKind::Empty ||
+      rhs.kind == StructuredSparsityKind::Empty)
+    return StructuredSparsityPattern(StructuredSparsityKind::Empty);
+
+  auto lb = std::min(lhs.lowerBandwidth, rhs.lowerBandwidth);
+  auto ub = std::min(lhs.upperBandwidth, rhs.upperBandwidth);
+  auto newPattern = StructuredSparsityPattern(lb, ub);
+  newPattern.refineKind();
+  return newPattern;
+}
+
+StructuredSparsityPattern StructuredSparsityPattern::propagateTranspose(
+    Value val, const StructuredSparsityPattern &op) {
+  if (op.kind == StructuredSparsityKind::Empty ||
+      op.kind == StructuredSparsityKind::Unknown)
+    return StructuredSparsityPattern(op.kind);
+
+  auto newPattern =
+      StructuredSparsityPattern(op.upperBandwidth, op.lowerBandwidth);
+  newPattern.refineKind();
+  return newPattern;
+}
+
+void StructuredSparsityPattern::print(raw_ostream &os) const {
+  switch (kind) {
+  case StructuredSparsityKind::Unknown:
+    os << "Unknown";
+    break;
+  case StructuredSparsityKind::Dense:
+    os << "Dense";
+    break;
+  case StructuredSparsityKind::Band:
+    os << "Band(" << lowerBandwidth << ", " << upperBandwidth << ")";
+    break;
+  case StructuredSparsityKind::UpperTriangular:
+    os << "UpperTriangular";
+    break;
+  case StructuredSparsityKind::UpperBidiagonal:
+    os << "UpperBidiagonal";
+    break;
+  case StructuredSparsityKind::LowerTriangular:
+    os << "LowerTriangular";
+    break;
+  case StructuredSparsityKind::LowerBidiagonal:
+    os << "LowerBidiagonal";
+    break;
+  case StructuredSparsityKind::Tridiagonal:
+    os << "Tridiagonal";
+    break;
+  case StructuredSparsityKind::Diagonal:
+    os << "Diagonal";
+    break;
+  case StructuredSparsityKind::Empty:
+    os << "Empty";
+    break;
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Value Properties Implementation
+//===----------------------------------------------------------------------===//
+
+ValueProperties::ValueProperties(Value v) {
+  if (auto blockArg = dyn_cast<BlockArgument>(v)) {
+    // TODO: If block arg is annotated with a pattern we should parse that
+    setFlags(0); // be pessimistic by default
+    return;
+  }
+
+  auto vTy = cast<RankedTensorType>(v.getType());
+  if (!vTy.hasStaticShape() || vTy.getRank() != 2)
+    return;
+  auto vShape = vTy.getShape();
+  if (vShape[0] != vShape[1]) // TODO: should we allow rectangular matrices?
+    return;
+
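+  // Constants: derive the properties directly from the element data.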
+  DenseElementsAttr denseAttr;
+  if (matchPattern(v, m_Constant(&denseAttr))) {
+    auto props = getPropertiesFromDenseAttr(denseAttr);
+    setFlags(props.getFlags());
+    return;
+  }
+
+  auto defOp = v.getDefiningOp();
+  if (!defOp)
+    return;
+
+  // check that transpose dimensions are [1,0]
+  auto isTrueTranspose = [](stablehlo::TransposeOp tOp) -> bool {
+    auto perm = tOp.getPermutation();
+    return perm.size() == 2 && perm[0] == 1 && perm[1] == 0;
+  };
+
+  // comm_op(A, A^T) will always be symmetric
+  // TODO: confirm the exact trait set; commutativity is what this check
+  // assumes.
+  if (stablehlo::hasTraitElementwise(defOp) &&
+      defOp->hasTrait<OpTrait::IsCommutative>()) {
+    auto lhs = defOp->getOperand(0);
+    auto rhs = defOp->getOperand(1);
+
+    if (auto rhsT = rhs.getDefiningOp<stablehlo::TransposeOp>()) {
+      if (isTrueTranspose(rhsT) && lhs == rhsT.getOperand()) {
+        set(ValueProperty::Symmetric);
+      }
+    }
+
+    if (auto lhsT = lhs.getDefiningOp<stablehlo::TransposeOp>()) {
+      if (isTrueTranspose(lhsT) && rhs == lhsT.getOperand()) {
+        set(ValueProperty::Symmetric);
+      }
+    }
+  }
+
+  if (auto dotGeneralOp = dyn_cast<stablehlo::DotGeneralOp>(defOp)) {
+    auto dotDimNumbers = dotGeneralOp.getDotDimensionNumbers();
+    auto lhs = dotGeneralOp.getLhs();
+    auto rhs = dotGeneralOp.getRhs();
+
+    if (dotDimNumbers.getLhsBatchingDimensions().size() == 0 &&
+        dotDimNumbers.getRhsBatchingDimensions().size() == 0) {
+      // lhs == rhs => check for the dimension numbers
+      if (lhs == rhs) {
+        if (dotDimNumbers.getLhsContractingDimensions().size() == 1 &&
+            dotDimNumbers.getRhsContractingDimensions().size() == 1 &&
+            dotDimNumbers.getLhsContractingDimensions()[0] ==
+                dotDimNumbers.getRhsContractingDimensions()[0]) {
+          set(ValueProperty::Symmetric);
+        }
+      }
+
+      // check operands are transposed: `A x A^T` and `A^T x A`
+      if (auto lhsT = lhs.getDefiningOp<stablehlo::TransposeOp>()) {
+        if (isTrueTranspose(lhsT) && rhs == lhsT.getOperand()) {
+          if (dotDimNumbers.getLhsContractingDimensions().size() == 1 &&
+              dotDimNumbers.getRhsContractingDimensions().size() == 1 &&
+              dotDimNumbers.getLhsContractingDimensions()[0] ==
+                  1 - dotDimNumbers.getRhsContractingDimensions()[0]) {
+            set(ValueProperty::Symmetric);
+          }
+        }
+      }
+
+      if (auto rhsT = rhs.getDefiningOp<stablehlo::TransposeOp>()) {
+        if (isTrueTranspose(rhsT) && lhs == rhsT.getOperand()) {
+          if (dotDimNumbers.getLhsContractingDimensions().size() == 1 &&
+              dotDimNumbers.getRhsContractingDimensions().size() == 1 &&
+              dotDimNumbers.getLhsContractingDimensions()[0] ==
+                  1 - dotDimNumbers.getRhsContractingDimensions()[0]) {
+            set(ValueProperty::Symmetric);
+          }
+        }
+      }
+    }
+  }
+
+  if (auto bcastOp = dyn_cast<stablehlo::BroadcastInDimOp>(defOp)) {
+    auto operand = bcastOp.getOperand();
+    if (cast<RankedTensorType>(operand.getType()).getRank() ==
+        0) { // bcast(scalar)
+      if (matchPattern(operand, m_One())) // bcast(1)
+        set(ValueProperty::UnitDiagonal);
+      set(ValueProperty::BroadcastedScalar);
+      set(ValueProperty::Symmetric);
+      return;
+    }
+  }
+
+  // TODO: unit diagonal
+  //   - iota scatter with constant
+
+  return;
+}
+
+ValueProperties
+ValueProperties::getPropertiesFromDenseAttr(DenseElementsAttr attr) {
+  ValueProperties props;
+
+  if (attr.isSplat()) {
+    auto val = attr.getSplatValue<Attribute>();
+    if (utils::isOne(val))
+      props.set(ValueProperty::UnitDiagonal);
+
+    props.set(ValueProperty::BroadcastedScalar);
+    props.set(ValueProperty::Symmetric);
+    props.set(ValueProperty::Hermitian);
+    return props;
+  }
+
+  auto type = dyn_cast<RankedTensorType>(attr.getType());
+  if (!type)
+    return props;
+
+  auto shape = type.getShape();
+  int64_t nrows = shape[0];
+  int64_t ncols = shape[1];
+  if (nrows != ncols)
+    return props;
+
+  if (isUnitDiagonal(attr, nrows, ncols))
+    props.set(ValueProperty::UnitDiagonal);
+
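+  // Element-wise scan comparing the strict upper triangle with its mirror.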
+  auto [isSymmetric, isHermitian] = isSymmetricOrHermitian(attr, nrows, ncols);
+  if (isSymmetric)
+    props.set(ValueProperty::Symmetric);
+  if (isHermitian)
+    props.set(ValueProperty::Hermitian);
+
+  return props;
+}
+
+template <typename T>
+bool isUnitDiagonalImpl(DenseElementsAttr attr, int64_t nrows, int64_t ncols) {
+  auto values = attr.getValues<T>().begin();
+  for (int64_t i = 0; i < std::min(nrows, ncols); i++) {
+    // diagonal entry (i, i) in the row-major element list
+    if (!utils::isOne(values[i * ncols + i]))
+      return false;
+  }
+  return true;
+}
+
+bool ValueProperties::isUnitDiagonal(DenseElementsAttr attr, int64_t nrows,
+                                     int64_t ncols) {
+  if (isa<IntegerType>(attr.getElementType())) {
+    return isUnitDiagonalImpl<APInt>(attr, nrows, ncols);
+  } else if (isa<FloatType>(attr.getElementType())) {
+    return isUnitDiagonalImpl<APFloat>(attr, nrows, ncols);
+  }
+  return false;
+}
+
+template <typename T>
+std::tuple<bool, bool> isSymmetricOrHermitianImpl(DenseElementsAttr attr,
+                                                  int64_t nrows,
+                                                  int64_t ncols) {
+  auto values = attr.getValues<T>().begin();
+  for (int64_t i = 0; i < nrows; i++) {
+    for (int64_t j = i + 1; j < ncols; j++) {
+      auto a = *(values + i * ncols + j);
+      auto b = *(values + j * ncols + i);
+      if (!utils::areEqual(a, b)) {
+        return {false, false}; // TODO: check for hermitian
+      }
+    }
+  }
+
+  return {true, false}; // TODO: check for hermitian
+}
+
+std::tuple<bool, bool>
+ValueProperties::isSymmetricOrHermitian(DenseElementsAttr attr, int64_t nrows,
+                                        int64_t ncols) {
+  if (isa<IntegerType>(attr.getElementType())) {
+    return isSymmetricOrHermitianImpl<APInt>(attr, nrows, ncols);
+  } else if (isa<FloatType>(attr.getElementType())) {
+    return isSymmetricOrHermitianImpl<APFloat>(attr, nrows, ncols);
+  }
+  return {false, false};
+}
+
+ValueProperties ValueProperties::meet(const ValueProperties &lhs,
+                                      const ValueProperties &rhs) {
+  return ValueProperties(lhs.flags & rhs.flags);
+}
+
+ValueProperties ValueProperties::join(const ValueProperties &lhs,
+                                      const ValueProperties &rhs) {
+  return ValueProperties(lhs.flags | rhs.flags);
+}
+
+void ValueProperties::print(raw_ostream &os) const {
+  os << "{";
+  bool first = true;
+  auto add = [&](const char *s) {
+    if (!first)
+      os << ", ";
+    os << s;
+    first = false;
+  };
+
+  if (hasUnitDiagonal())
+    add("UnitDiagonal");
+  if (isSymmetric())
+    add("Symmetric");
+  if (isHermitian())
+    add("Hermitian");
+  if (isBroadcastedScalar())
+    add("BroadcastedScalar");
+
+  os << "}";
+}
+
+//===----------------------------------------------------------------------===//
+// Structured Matrix Type
+//===----------------------------------------------------------------------===//
+
+StructuredMatrixType
+StructuredMatrixType::meet(const StructuredMatrixType &lhs,
+                           const StructuredMatrixType &rhs) {
+  return StructuredMatrixType(
+      StructuredSparsityPattern::meet(lhs.sparsityPattern, rhs.sparsityPattern),
+      ValueProperties::meet(lhs.valueProperties, rhs.valueProperties));
+}
+
+StructuredMatrixType
+StructuredMatrixType::join(const StructuredMatrixType &lhs,
+                           const StructuredMatrixType &rhs) {
+  return StructuredMatrixType(
+      StructuredSparsityPattern::join(lhs.sparsityPattern, rhs.sparsityPattern),
+      ValueProperties::join(lhs.valueProperties, rhs.valueProperties));
+}
+
+void StructuredMatrixType::print(raw_ostream &os) const {
+  os << "StructuredMatrixType(";
+  sparsityPattern.print(os);
+  os << " ";
+  valueProperties.print(os);
+  os << ")";
+}
+
+StructuredMatrixType
+StructuredMatrixType::propagateTranspose(Value val,
+                                         const StructuredMatrixType &op) {
+  return StructuredMatrixType(
+      StructuredSparsityPattern::propagateTranspose(val, op.sparsityPattern),
+      op.valueProperties);
+}
+
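+// A + B keeps the union of both operands' sparsity patterns (the meet below),
+// stays symmetric when both addends are symmetric, and keeps a unit diagonal
+// only when the other addend is a constant zero.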
+StructuredMatrixType
+StructuredMatrixType::propagateAdd(Value lhs, Value rhs,
+                                   const StructuredMatrixType &lhsType,
+                                   const StructuredMatrixType &rhsType) {
+  ValueProperties valProps;
+
+  // If one operand has a unit diagonal and the other is all zeros, the sum
+  // keeps the unit diagonal
+  SplatElementsAttr lhsSplat, rhsSplat;
+  if (lhsType.getProperties().hasUnitDiagonal() &&
+      matchPattern(rhs, m_Constant(&rhsSplat))) {
+    if (utils::isZero(rhsSplat.getSplatValue<Attribute>())) {
+      valProps.set(ValueProperty::UnitDiagonal);
+    }
+  }
+  if (rhsType.getProperties().hasUnitDiagonal() &&
+      matchPattern(lhs, m_Constant(&lhsSplat))) {
+    if (utils::isZero(lhsSplat.getSplatValue<Attribute>())) {
+      valProps.set(ValueProperty::UnitDiagonal);
+    }
+  }
+
+  if (lhsType.getProperties().isSymmetric() &&
+      rhsType.getProperties().isSymmetric()) {
+    valProps.set(ValueProperty::Symmetric);
+  }
+  if (lhsType.getProperties().isBroadcastedScalar() &&
+      rhsType.getProperties().isBroadcastedScalar()) {
+    valProps.set(ValueProperty::BroadcastedScalar);
+  }
+
+  return StructuredMatrixType(
+      StructuredSparsityPattern::meet(lhsType.sparsityPattern,
+                                      rhsType.sparsityPattern),
+      valProps);
+}
+
+StructuredMatrixType
+StructuredMatrixType::propagateMultiply(Value lhs, Value rhs,
+                                        const StructuredMatrixType &lhsType,
+                                        const StructuredMatrixType &rhsType) {
+  return StructuredMatrixType::meet(lhsType, rhsType);
+}
+
+// TODO: we ideally want to special case elementwise ops that preserve certain
+// properties
+StructuredMatrixType StructuredMatrixType::propagateElementwise(
+    ArrayRef<Value> operands,
+    SmallVectorImpl<StructuredMatrixType> &operandsType) {
+  // TODO: propagate structure
+
+  ValueProperties valueProperties;
+  // TODO: propagate hermitian
+  bool allSymmetric = true, allScalar = true;
+  for (auto opType : operandsType) {
+    if (!opType.getProperties().isSymmetric()) {
+      allSymmetric = false;
+    }
+    if (!opType.getProperties().isBroadcastedScalar()) {
+      allScalar = false;
+    }
+  }
+  if (allSymmetric) {
+    valueProperties.set(ValueProperty::Symmetric);
+  }
+  if (allScalar) {
+    valueProperties.set(ValueProperty::BroadcastedScalar);
+  }
+
+  return StructuredMatrixType(StructuredSparsityPattern(), valueProperties);
+}
+
+//===----------------------------------------------------------------------===//
+// Lattice Element
+//===----------------------------------------------------------------------===//
+
+ChangeResult StructuredMatrixLattice::meet(const AbstractSparseLattice &rhs) {
+  const auto *rhsStruct =
+      reinterpret_cast<const StructuredMatrixLattice *>(&rhs);
+  return meet(*rhsStruct);
+}
+
+ChangeResult StructuredMatrixLattice::meet(StructuredMatrixLattice rhs) {
+  auto newValue = StructuredMatrixType::meet(getValue(), rhs.getValue());
+  if (getValue() == newValue)
+    return ChangeResult::NoChange;
+
+  setValue(newValue);
+  return ChangeResult::Change;
+}
+
+ChangeResult StructuredMatrixLattice::join(const AbstractSparseLattice &rhs) {
+  const auto *rhsStruct =
+      reinterpret_cast<const StructuredMatrixLattice *>(&rhs);
+  return join(*rhsStruct);
+}
+
+ChangeResult StructuredMatrixLattice::join(StructuredMatrixLattice rhs) {
+  auto newValue = StructuredMatrixType::join(getValue(), rhs.getValue());
+  if (getValue() == newValue)
+    return ChangeResult::NoChange;
+
+  setValue(newValue);
+  return ChangeResult::Change;
+}
+
+void StructuredMatrixLattice::print(raw_ostream &os) const {
+  os << "StructuredMatrixLattice(";
+  value.print(os);
+  os << ")";
+}
+
+//===----------------------------------------------------------------------===//
+// Dataflow Analysis
+//===----------------------------------------------------------------------===//
+
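+// Entry values (e.g. unannotated function arguments) start from the default
+// StructuredMatrixType: Unknown sparsity and no value properties.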
+void StructuredMatrixAnalysis::setToEntryState(
+    StructuredMatrixLattice *lattice) {
+  lattice->setValue(StructuredMatrixType());
+}
+
+LogicalResult StructuredMatrixAnalysis::visitOperation(
+    Operation *op, ArrayRef<const StructuredMatrixLattice *> operands,
+    ArrayRef<StructuredMatrixLattice *> results) {
+  SmallVector<bool> updatedProps(results.size(), false);
+  SmallVector<StructuredMatrixType> propagatedProps(results.size());
+
+  SmallVector<StructuredMatrixType> operandValues(operands.size());
+  for (size_t i = 0; i < operands.size(); i++) {
+    operandValues[i] = operands[i]->getValue();
+  }
+
+  // transpose
+  if (auto transposeOp = dyn_cast<stablehlo::TransposeOp>(op)) {
+    updatedProps[0] = true;
+    propagatedProps[0] = StructuredMatrixType::propagateTranspose(
+        transposeOp.getOperand(), operandValues[0]);
+  }
+
+  // elementwise
+  /// add
+  if (auto addOp = dyn_cast<stablehlo::AddOp>(op)) {
+    updatedProps[0] = true;
+    propagatedProps[0] = StructuredMatrixType::propagateAdd(
+        addOp.getLhs(), addOp.getRhs(), operandValues[0], operandValues[1]);
+  }
+
+  /// mul
+  if (auto mulOp = dyn_cast<stablehlo::MulOp>(op)) {
+    updatedProps[0] = true;
+    propagatedProps[0] = StructuredMatrixType::propagateMultiply(
+        mulOp.getLhs(), mulOp.getRhs(), operandValues[0], operandValues[1]);
+  }
+
+  /// fallback for other elementwise ops (only if no specialized rule above
+  /// already produced a result)
+  if (stablehlo::hasTraitElementwise(op) && !updatedProps[0]) {
+    updatedProps[0] = true;
+    propagatedProps[0] = StructuredMatrixType::propagateElementwise(
+        llvm::to_vector<3>(op->getOperands()), operandValues);
+  }
+
+  // pass through ops
+  // TODO: confirm/extend the pass-through set; element-type conversion is the
+  // obvious structure-preserving case.
+  if (isa<stablehlo::ConvertOp>(op)) {
+    updatedProps[0] = true;
+    propagatedProps[0] = operandValues[0];
+  }
+
+  // finalize
+  for (size_t i = 0; i < results.size(); i++) {
+    if (updatedProps[i]) {
+      auto resultOrig = results[i]->getValue();
+      auto resultNew =
+          StructuredMatrixType::join(resultOrig, propagatedProps[i]);
+      results[i]->setValue(resultNew);
+      propagateIfChanged(results[i], resultNew == resultOrig
+                                         ? ChangeResult::NoChange
+                                         : ChangeResult::Change);
+    }
+  }
+
+  return success();
+}
+
+} // namespace structure_analysis
+} // namespace mlir
diff --git a/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h
new file mode 100644
index 000000000..8ccd7360d
--- /dev/null
+++ b/src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h
@@ -0,0 +1,292 @@
+#pragma once
+
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+
+#include "src/enzyme_ad/jax/Dialect/Ops.h"
+
+#include <cstdint>
+#include <tuple>
+
+namespace mlir {
+namespace structure_analysis {
+
+namespace utils {
+
+static bool isZero(APInt v) { return v.isZero(); }
+static bool isZero(APFloat v) { return v.isZero(); }
+static bool isZero(Attribute v) {
+  if (auto intAttr = dyn_cast<IntegerAttr>(v))
+    return isZero(intAttr.getValue());
+  if (auto floatAttr = dyn_cast<FloatAttr>(v))
+    return isZero(floatAttr.getValue());
+  return false;
+}
+
+static bool isOne(APInt v) { return v.isOne(); }
+static bool isOne(APFloat v) { return v.isExactlyValue(1.0); }
+static bool isOne(Attribute v) {
+  if (auto intAttr = dyn_cast<IntegerAttr>(v))
+    return isOne(intAttr.getValue());
+  if (auto floatAttr = dyn_cast<FloatAttr>(v))
+    return isOne(floatAttr.getValue());
+  return false;
+}
+
+static bool areEqual(APInt a, APInt b) { return a == b; }
+static bool areEqual(APFloat a, APFloat b) {
+  return a.compare(b) == llvm::APFloat::cmpEqual;
+}
+
+} // namespace utils
+
+//===----------------------------------------------------------------------===//
+// Structured Sparsity Pattern Implementation
+//===----------------------------------------------------------------------===//
+
+enum class StructuredSparsityKind {
+  Dense,
+  Band,
+  UpperTriangular,
+  UpperBidiagonal,
+  LowerTriangular,
+  LowerBidiagonal,
+  Tridiagonal,
+  Diagonal,
+  Empty, // denotes that all elements are structural zeros
+  Unknown,
+};
+
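+// Bandwidth convention (LAPACK-style): lowerBandwidth counts the nonzero
+// diagonals strictly below the main diagonal and upperBandwidth those strictly
+// above it, so Diagonal is (0, 0), Tridiagonal is (1, 1), and Dense uses the
+// int64_t maximum for "no bound".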
+// TODO: currently only legal negative value is -1, which means "unknown"
+// we should support negative bandwidths
+class StructuredSparsityPattern {
+public:
+  StructuredSparsityPattern()
+      : kind(StructuredSparsityKind::Unknown), lowerBandwidth(-1),
+        upperBandwidth(-1) {}
+
+  explicit StructuredSparsityPattern(StructuredSparsityKind kind)
+      : kind(kind), lowerBandwidth(-1), upperBandwidth(-1) {
+    initializeBandwidths();
+  }
+
+  StructuredSparsityPattern(Value v);
+
+  StructuredSparsityPattern(int64_t lowerBandwidth, int64_t upperBandwidth)
+      : kind(StructuredSparsityKind::Band), lowerBandwidth(lowerBandwidth),
+        upperBandwidth(upperBandwidth) {
+    refineKind();
+  }
+
+  StructuredSparsityKind getKind() const { return kind; }
+  int64_t getLowerBandwidth() const { return lowerBandwidth; }
+  int64_t getUpperBandwidth() const { return upperBandwidth; }
+
+  static StructuredSparsityPattern meet(const StructuredSparsityPattern &lhs,
+                                        const StructuredSparsityPattern &rhs);
+
+  static StructuredSparsityPattern join(const StructuredSparsityPattern &lhs,
+                                        const StructuredSparsityPattern &rhs);
+
+  bool operator==(const StructuredSparsityPattern &other) const {
+    return kind == other.kind && lowerBandwidth == other.lowerBandwidth &&
+           upperBandwidth == other.upperBandwidth;
+  }
+
+  void print(raw_ostream &os) const;
+  raw_ostream &operator<<(raw_ostream &os) const {
+    print(os);
+    return os;
+  }
+
+  // propagation rules
+  static StructuredSparsityPattern
+  propagateTranspose(Value val, const StructuredSparsityPattern &op);
+
+private:
+  void initializeBandwidths();
+  void refineKind();
+
+  void setUnknown() {
+    kind = StructuredSparsityKind::Unknown;
+    lowerBandwidth = -1;
+    upperBandwidth = -1;
+  }
+
+  StructuredSparsityKind kind;
+  int64_t lowerBandwidth;
+  int64_t upperBandwidth;
+};
+
+//===----------------------------------------------------------------------===//
+// Value Properties Implementation
+//===----------------------------------------------------------------------===//
+
+enum class ValueProperty {
+  UnitDiagonal = 1 << 0,
+  Symmetric = 1 << 1,
+  Hermitian = 1 << 2,
+  BroadcastedScalar = 1 << 3,
+};
+
+class ValueProperties {
+public:
+  ValueProperties() = default;
+  explicit ValueProperties(uint32_t flags) : flags(flags) {}
+
+  ValueProperties(Value v);
+
+  void set(ValueProperty property) {
+    flags |= static_cast<uint32_t>(property);
+  }
+  void clear(ValueProperty property) {
+    flags &= ~static_cast<uint32_t>(property);
+  }
+  bool has(ValueProperty property) const {
+    return flags & static_cast<uint32_t>(property);
+  }
+
+  bool hasUnitDiagonal() const { return has(ValueProperty::UnitDiagonal); }
+  bool isSymmetric() const { return has(ValueProperty::Symmetric); }
+  bool isHermitian() const { return has(ValueProperty::Hermitian); }
+  bool isBroadcastedScalar() const {
+    return has(ValueProperty::BroadcastedScalar);
+  }
+
+  void print(raw_ostream &os) const;
+  raw_ostream &operator<<(raw_ostream &os) const {
+    print(os);
+    return os;
+  }
+
+  uint32_t getFlags() const { return flags; }
+  void setFlags(uint32_t f) { flags = f; }
+
+  static ValueProperties meet(const ValueProperties &lhs,
+                              const ValueProperties &rhs);
+
+  static ValueProperties join(const ValueProperties &lhs,
+                              const ValueProperties &rhs);
+
+  bool operator==(const ValueProperties &other) const {
+    return flags == other.flags;
+  }
+
+private:
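+  // Helpers that inspect constant element data; integer and float element
+  // types are dispatched separately.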
+  static ValueProperties getPropertiesFromDenseAttr(DenseElementsAttr attr);
+
+  static bool isUnitDiagonal(DenseElementsAttr attr, int64_t nrows,
+                             int64_t ncols);
+  static std::tuple<bool, bool>
+  isSymmetricOrHermitian(DenseElementsAttr, int64_t nrows, int64_t ncols);
+
+  uint32_t flags = 0;
+};
+
+//===----------------------------------------------------------------------===//
+// Structured Matrix Type
+//===----------------------------------------------------------------------===//
+
+class StructuredMatrixType {
+public:
+  StructuredMatrixType() = default;
+  StructuredMatrixType(StructuredSparsityPattern sparsityPattern,
+                       ValueProperties valueProperties)
+      : sparsityPattern(sparsityPattern), valueProperties(valueProperties) {}
+
+  StructuredMatrixType(Value v)
+      : StructuredMatrixType(StructuredSparsityPattern(v), ValueProperties(v)) {
+  }
+
+  const StructuredSparsityPattern &getSparsityPattern() const {
+    return sparsityPattern;
+  }
+  const ValueProperties &getProperties() const { return valueProperties; }
+
+  static StructuredMatrixType meet(const StructuredMatrixType &lhs,
+                                   const StructuredMatrixType &rhs);
+
+  static StructuredMatrixType join(const StructuredMatrixType &lhs,
+                                   const StructuredMatrixType &rhs);
+
+  bool operator==(const StructuredMatrixType &other) const {
+    return sparsityPattern == other.sparsityPattern &&
+           valueProperties == other.valueProperties;
+  }
+
+  void print(raw_ostream &os) const;
+  raw_ostream &operator<<(raw_ostream &os) const {
+    print(os);
+    return os;
+  }
+
+  // propagation rules
+  static StructuredMatrixType
+  propagateTranspose(Value val, const StructuredMatrixType &op);
+
+  static StructuredMatrixType
+  propagateAdd(Value lhs, Value rhs, const StructuredMatrixType &lhsType,
+               const StructuredMatrixType &rhsType);
+
+  static StructuredMatrixType
+  propagateMultiply(Value lhs, Value rhs, const StructuredMatrixType &lhsType,
+                    const StructuredMatrixType &rhsType);
+
+  static StructuredMatrixType
+  propagateElementwise(ArrayRef<Value> operands,
+                       SmallVectorImpl<StructuredMatrixType> &operandsType);
+
+  // TODO: implement queries that check both the sparsity pattern and value
+  // properties and return specific matrix kinds
+
+private:
+  StructuredSparsityPattern sparsityPattern;
+  ValueProperties valueProperties;
+};
+
+//===----------------------------------------------------------------------===//
+// Lattice Element
+//===----------------------------------------------------------------------===//
+
+class StructuredMatrixLattice : public dataflow::AbstractSparseLattice {
+public:
+  using AbstractSparseLattice::AbstractSparseLattice;
+
+  StructuredMatrixLattice(Value v)
+      : AbstractSparseLattice(v), value(StructuredMatrixType(v)) {}
+
+  ChangeResult meet(const AbstractSparseLattice &rhs) override;
+  ChangeResult meet(StructuredMatrixLattice rhs);
+
+  ChangeResult join(const AbstractSparseLattice &rhs) override;
+  ChangeResult join(StructuredMatrixLattice rhs);
+
+  void print(raw_ostream &os) const override;
+  raw_ostream &operator<<(raw_ostream &os) const {
+    print(os);
+    return os;
+  }
+
+  const StructuredMatrixType &getValue() const { return value; }
+  void setValue(const StructuredMatrixType &v) { value = v; }
+
+private:
+  StructuredMatrixType value;
+};
+
+//===----------------------------------------------------------------------===//
+// Dataflow Analysis
+//===----------------------------------------------------------------------===//
+
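+// Sparse forward dataflow analysis that attaches a StructuredMatrixLattice to
+// every SSA value and propagates structure across stablehlo operations.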
+class StructuredMatrixAnalysis
+    : public dataflow::SparseForwardDataFlowAnalysis<StructuredMatrixLattice> {
+public:
+  using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
+
+  void setToEntryState(StructuredMatrixLattice *lattice) override;
+
+  LogicalResult
+  visitOperation(Operation *op,
+                 ArrayRef<const StructuredMatrixLattice *> operands,
+                 ArrayRef<StructuredMatrixLattice *> results) override;
+};
+
+} // namespace structure_analysis
+} // namespace mlir
diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD
index 975f166f9..cf247db58 100644
--- a/src/enzyme_ad/jax/BUILD
+++ b/src/enzyme_ad/jax/BUILD
@@ -839,6 +839,7 @@ cc_library(
 cc_library(
     name = "XLADerivatives",
     srcs = glob([
+        "Analysis/*.cpp",
         "Implementations/*.cpp",
         "Passes/*.cpp",
         "Dialect/*.cpp",
@@ -848,6 +849,7 @@ cc_library(
         "Utils.cpp",
     ],
     hdrs = glob([
+        "Analysis/*.h",
         "Implementations/*.h",
         "Passes/*.h",
         "Dialect/*.h",
diff --git a/src/enzyme_ad/jax/Dialect/Dialect.cpp b/src/enzyme_ad/jax/Dialect/Dialect.cpp
index d9e0e1375..7a1569780 100644
--- a/src/enzyme_ad/jax/Dialect/Dialect.cpp
+++ b/src/enzyme_ad/jax/Dialect/Dialect.cpp
@@ -97,6 +97,122 @@ struct EnzymeXLADialectInlinerInterface : public DialectInlinerInterface {
 
 } // namespace
 
+void StructuredSparsityAttr::print(::mlir::AsmPrinter &printer) const {
+  printer << "<";
+
+  auto pattern = getPattern();
+  auto kind = pattern.getKind();
+
+  // Skip printing kind for Unknown
+  bool printKind = kind != StructuredSparsityKind::Unknown;
+
+  if (printKind) {
+    printer << stringifyStructuredSparsityKind(kind);
+
+    // Print bandwidth only for Band
+    if (kind == StructuredSparsityKind::Band) {
+      printer << " [" << pattern.getLowerBandwidth() << ", "
+              << pattern.getUpperBandwidth() << "]";
+    }
+  }
+
+  // Print value properties as a set
+  auto props = getValueProperties();
+  if (!props.empty()) {
+    if (printKind) {
+      printer << ", ";
+    }
+    printer << "{";
+    llvm::interleaveComma(props, printer, [&](StructuredValueProperty prop) {
+      printer << stringifyStructuredValueProperty(prop);
+    });
+    printer << "}";
+  }
+
+  printer << ">";
+}
+
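+// Accepted forms mirror print(): `<Kind>`, `<Band [lb, ub]>`, `<{Prop, ...}>`,
+// `<Kind, {Prop, ...}>`, or an empty `<>` (Unknown with no properties).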
+::mlir::Attribute StructuredSparsityAttr::parse(::mlir::AsmParser &parser,
+                                                ::mlir::Type type) {
+  if (parser.parseLess())
+    return {};
+
+  StructuredSparsityKind kind = StructuredSparsityKind::Unknown;
+  int64_t lowerBandwidth = -1;
+  int64_t upperBandwidth = -1;
+  llvm::SmallVector<StructuredValueProperty> valueProperties;
+
+  // Check if we start with a brace (properties only, no kind)
+  if (parser.parseOptionalLBrace().failed()) {
+    // Try to parse the kind
+    llvm::StringRef kindStr;
+    if (succeeded(parser.parseOptionalKeyword(&kindStr))) {
+      auto kindOpt = symbolizeStructuredSparsityKind(kindStr);
+      if (!kindOpt) {
+        parser.emitError(parser.getCurrentLocation(), "invalid sparsity kind: ")
+            << kindStr;
+        return {};
+      }
+      kind = *kindOpt;
+
+      // If Band, parse bandwidth bounds
+      if (kind == StructuredSparsityKind::Band) {
+        if (parser.parseLSquare() || parser.parseInteger(lowerBandwidth) ||
+            parser.parseComma() || parser.parseInteger(upperBandwidth) ||
+            parser.parseRSquare()) {
+          return {};
+        }
+      }
+
+      // Check for comma before properties
+      parser.parseOptionalComma();
+    }
+
+    // Try to parse value properties set
+    if (succeeded(parser.parseOptionalLBrace())) {
+      // Fall through to parse properties
+    } else {
+      // No properties
+      if (parser.parseGreater())
+        return {};
+
+      auto pattern = StructuredSparsityPatternAttr::get(
+          parser.getContext(), kind, lowerBandwidth, upperBandwidth);
+      return StructuredSparsityAttr::get(parser.getContext(), pattern,
+                                         valueProperties);
+    }
+  }
+
+  // Parse properties
+  if (!parser.parseOptionalRBrace().succeeded()) {
+    do {
+      llvm::StringRef propStr;
+      if (parser.parseKeyword(&propStr))
+        return {};
+
+      auto propOpt = symbolizeStructuredValueProperty(propStr);
+      if (!propOpt) {
+        parser.emitError(parser.getCurrentLocation(),
+                         "invalid value property: ")
+            << propStr;
+        return {};
+      }
+      valueProperties.push_back(*propOpt);
+    } while (succeeded(parser.parseOptionalComma()));
+
+    if (parser.parseRBrace())
+      return {};
+  }
+
+  if (parser.parseGreater())
+    return {};
+
+  auto pattern = StructuredSparsityPatternAttr::get(
+      parser.getContext(), kind, lowerBandwidth, upperBandwidth);
+  return StructuredSparsityAttr::get(parser.getContext(), pattern,
+                                     valueProperties);
+}
+
 void EnzymeXLADialect::initialize() {
   addInterfaces<EnzymeXLADialectInlinerInterface>();
   addOperations<
diff --git a/src/enzyme_ad/jax/Dialect/EnzymeXLAAttrs.td b/src/enzyme_ad/jax/Dialect/EnzymeXLAAttrs.td
index 11bf8ca25..b130c84d6 100644
--- a/src/enzyme_ad/jax/Dialect/EnzymeXLAAttrs.td
+++ b/src/enzyme_ad/jax/Dialect/EnzymeXLAAttrs.td
@@ -133,4 +133,77 @@ def EnzymeXLA_GuaranteedAnalysisResult : I32EnumAttr<"GuaranteedAnalysisResult",
 
 def EnzymeXLA_GuaranteedAnalysisResultAttr : EnumAttr<EnzymeXLA_Dialect, EnzymeXLA_GuaranteedAnalysisResult, "guaranteed_analysis_result">;
 
+def EnzymeXLA_StructuredSparsityKind : I32EnumAttr<"StructuredSparsityKind",
+    "Kind of sparsity pattern",
+    [
+      I32EnumAttrCase<"Unknown", 0>,
+      I32EnumAttrCase<"Dense", 1>,
+      I32EnumAttrCase<"Band", 2>,
+      I32EnumAttrCase<"UpperTriangular", 3>,
+      I32EnumAttrCase<"UpperBidiagonal", 4>,
+      I32EnumAttrCase<"LowerTriangular", 5>,
+      I32EnumAttrCase<"LowerBidiagonal", 6>,
+      I32EnumAttrCase<"Tridiagonal", 7>,
+      I32EnumAttrCase<"Diagonal", 8>,
+      I32EnumAttrCase<"Empty", 9>
+    ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::enzymexla";
+}
+
+def EnzymeXLA_StructuredSparsityKindAttr : EnumAttr<EnzymeXLA_Dialect,
+    EnzymeXLA_StructuredSparsityKind, "structured_sparsity_kind"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def EnzymeXLA_StructuredSparsityPattern : AttrDef<EnzymeXLA_Dialect,
+    "StructuredSparsityPattern"> {
+  let summary = "Structured Sparsity Pattern";
+
+  let parameters = (ins
+    "::mlir::enzymexla::StructuredSparsityKind":$kind,
+    "int64_t":$lowerBandwidth,
+    "int64_t":$upperBandwidth
+  );
+
+  let assemblyFormat = [{
+    `<` $kind ` ` `[` $lowerBandwidth `,` ` ` $upperBandwidth `]` `>`
+  }];
+
+  let mnemonic = "structured_sparsity_pattern";
+  let cppNamespace = "::mlir::enzymexla";
+}
+
+def EnzymeXLA_StructuredValueProperty : I32EnumAttr<"StructuredValueProperty",
+    "Value properties",
+    [
+      I32EnumAttrCase<"UnitDiagonal", 0>,
+      I32EnumAttrCase<"Symmetric", 1>,
+      I32EnumAttrCase<"Hermitian", 2>,
+      I32EnumAttrCase<"BroadcastedScalar", 3>
+    ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::enzymexla";
+}
+
+def EnzymeXLA_StructuredValuePropertyAttr : EnumAttr<EnzymeXLA_Dialect,
+    EnzymeXLA_StructuredValueProperty, "structured_value_property"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def EnzymeXLA_StructuredSparsity : AttrDef<EnzymeXLA_Dialect, "StructuredSparsity"> {
+  let summary = "Structured Sparsity";
+  let cppNamespace = "::mlir::enzymexla";
+
+  let parameters = (ins
+    EnzymeXLA_StructuredSparsityPattern:$pattern,
+    ArrayRefParameter<"StructuredValueProperty">:$valueProperties
+  );
+
+  let mnemonic = "structured_sparsity";
+
+  let hasCustomAssemblyFormat = 1;
+}
+
 #endif // ENZYMEXLA_ATTRS
diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td
index 5bf009201..b29fa212f 100644
--- a/src/enzyme_ad/jax/Passes/Passes.td
+++ b/src/enzyme_ad/jax/Passes/Passes.td
@@ -1077,4 +1077,13 @@ def EnzymeBatchToStableHLOPass : Pass<"enzyme-batch-to-stablehlo"> {
   ];
 }
 
+def StructuredMatrixSimplifyPass : Pass<"structured-matrix-simplify", "ModuleOp"> {
+  let summary = "Simplify structured matrix operations";
+  let dependentDialects = [
+    "stablehlo::StablehloDialect",
+    "enzymexla::EnzymeXLADialect",
+    "func::FuncDialect",
+  ];
+}
+
 #endif
diff --git a/src/enzyme_ad/jax/Passes/StructuredMatrixSimplify.cpp b/src/enzyme_ad/jax/Passes/StructuredMatrixSimplify.cpp
new file mode 100644
index 000000000..ac6ed2f22
--- /dev/null
+++ b/src/enzyme_ad/jax/Passes/StructuredMatrixSimplify.cpp
@@ -0,0 +1,161 @@
+#include "src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h"
+#include "src/enzyme_ad/jax/Passes/Passes.h"
+
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Pass/Pass.h"
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "src/enzyme_ad/jax/Dialect/Dialect.h"
+#include "stablehlo/dialect/StablehloOps.h"
+
+#define DEBUG_TYPE "structured-matrix-simplify"
+
+namespace mlir {
+namespace enzyme {
+#define GEN_PASS_DEF_STRUCTUREDMATRIXSIMPLIFYPASS
+#include "src/enzyme_ad/jax/Passes/Passes.h.inc"
+} // namespace enzyme
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::dataflow;
+using namespace mlir::enzyme;
+using namespace mlir::structure_analysis;
+
+namespace {
+
+class StructuredMatrixSimplifyPass
+    : public enzyme::impl::StructuredMatrixSimplifyPassBase<
+          StructuredMatrixSimplifyPass> {
+public:
+  using Base::Base;
+
+  void runOnOperation() override {
+    DataFlowSolver solver;
+
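+    // Dead-code and sparse constant-propagation analyses are loaded alongside
+    // the structured-matrix analysis so the solver visits live operations.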
+    solver.load<DeadCodeAnalysis>();
+    solver.load<SparseConstantPropagation>();
+    solver.load<StructuredMatrixAnalysis>();
+
+    if (failed(solver.initializeAndRun(getOperation()))) {
+      return signalPassFailure();
+    }
+
+    auto mod = getOperation();
+
+    // TODO: make IR annotation optional via an option
+    mod->walk([&](Operation *op) {
+      SmallVector<Attribute> structuredSparsityAttrs;
+      bool anyKnown = false;
+      for (auto result : op->getResults()) {
+        auto *state =
+            solver.lookupState<structure_analysis::StructuredMatrixLattice>(
+                result);
+        if (!state) {
+          structuredSparsityAttrs.push_back(
+              enzymexla::StructuredSparsityAttr::get(
+                  mod.getContext(),
+                  enzymexla::StructuredSparsityPatternAttr::get(
+                      mod.getContext(),
+                      enzymexla::StructuredSparsityKind::Unknown, -1, -1),
+                  SmallVector<enzymexla::StructuredValueProperty>()));
+          continue;
+        }
+
+        if (state->getValue().getSparsityPattern().getKind() !=
+            mlir::structure_analysis::StructuredSparsityKind::Unknown) {
+          anyKnown = true;
+        }
+
+        enzymexla::StructuredSparsityKind ssKind;
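+        // Map the analysis-side kind onto the dialect attribute kind; the two
+        // enums are kept in sync case by case.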
+        switch (state->getValue().getSparsityPattern().getKind()) {
+        case mlir::structure_analysis::StructuredSparsityKind::Unknown:
+          ssKind = enzymexla::StructuredSparsityKind::Unknown;
+          break;
+        case mlir::structure_analysis::StructuredSparsityKind::Dense:
+          ssKind = enzymexla::StructuredSparsityKind::Dense;
+          break;
+        case mlir::structure_analysis::StructuredSparsityKind::Band:
+          ssKind = enzymexla::StructuredSparsityKind::Band;
+          break;
+        case mlir::structure_analysis::StructuredSparsityKind::UpperTriangular:
+          ssKind = enzymexla::StructuredSparsityKind::UpperTriangular;
+          break;
+        case mlir::structure_analysis::StructuredSparsityKind::UpperBidiagonal:
+          ssKind = enzymexla::StructuredSparsityKind::UpperBidiagonal;
+          break;
+        case mlir::structure_analysis::StructuredSparsityKind::LowerTriangular:
+          ssKind = enzymexla::StructuredSparsityKind::LowerTriangular;
+          break;
+        case mlir::structure_analysis::StructuredSparsityKind::LowerBidiagonal:
+          ssKind = enzymexla::StructuredSparsityKind::LowerBidiagonal;
+          break;
+        case mlir::structure_analysis::StructuredSparsityKind::Tridiagonal:
+          ssKind = enzymexla::StructuredSparsityKind::Tridiagonal;
+          break;
+        case mlir::structure_analysis::StructuredSparsityKind::Diagonal:
+          ssKind = enzymexla::StructuredSparsityKind::Diagonal;
+          break;
+        case mlir::structure_analysis::StructuredSparsityKind::Empty:
+          ssKind = enzymexla::StructuredSparsityKind::Empty;
+          break;
+        }
+
+        auto structuredSparsityKind =
+            enzymexla::StructuredSparsityPatternAttr::get(
+                mod.getContext(), ssKind,
+                state->getValue().getSparsityPattern().getLowerBandwidth(),
+                state->getValue().getSparsityPattern().getUpperBandwidth());
+
+        SmallVector<enzymexla::StructuredValueProperty>
+            structuredValueProperties;
+        auto valueProperties = state->getValue().getProperties();
+        if (valueProperties.hasUnitDiagonal()) {
+          anyKnown = true;
+          structuredValueProperties.push_back(
+              enzymexla::StructuredValueProperty::UnitDiagonal);
+        }
+        if (valueProperties.isSymmetric()) {
+          anyKnown = true;
+          structuredValueProperties.push_back(
+              enzymexla::StructuredValueProperty::Symmetric);
+        }
+        if (valueProperties.isHermitian()) {
+          anyKnown = true;
+          structuredValueProperties.push_back(
+              enzymexla::StructuredValueProperty::Hermitian);
+        }
+        if (valueProperties.isBroadcastedScalar()) {
+          anyKnown = true;
+          structuredValueProperties.push_back(
+              enzymexla::StructuredValueProperty::BroadcastedScalar);
+        }
+
+        auto structuredSparsity = enzymexla::StructuredSparsityAttr::get(
+            mod.getContext(), structuredSparsityKind,
+            structuredValueProperties);
+
+        structuredSparsityAttrs.push_back(structuredSparsity);
+      }
+
+      if (anyKnown) {
+        op->setAttr("enzymexla.structured_sparsity",
+                    ArrayAttr::get(mod.getContext(), structuredSparsityAttrs));
+      }
+
+      return WalkResult::advance();
+    });
+
+    // TODO: do things here
+  }
+};
+
+} // namespace
diff --git a/test/lit_tests/structured_tensors/banded.mlir b/test/lit_tests/structured_tensors/banded.mlir
new file mode 100644
index 000000000..4f4644b42
--- /dev/null
+++ b/test/lit_tests/structured_tensors/banded.mlir
@@ -0,0 +1,25 @@
+// RUN: enzymexlamlir-opt --structured-matrix-simplify %s | FileCheck %s
+
+// func.func @main1(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> {
+//   %c = stablehlo.constant dense<[[true, true, true, true, true, true, true, true, true, true], [true, true, true, true, true, true, true, true, true, true], [true, true, true, true, true, true, true, true, true, true], [false, true, true, true, true, true, true, true, true, true], [false, false, true, true, true, true, true, true, true, true], [false, false, false, true, true, true, true, true, true, true], [false, false, false, false, true, true, true, true, true, true], [false, false, false, false, false, true, true, true, true, true], [false, false, false, false, false, false, true, true, true, true], [false, false, false, false, false, false, false, true, true, true]]> : tensor<10x10xi1>
+//   %cst = stablehlo.constant dense<0.000000e+00> : tensor<10x10xf32>
+//   %c_0 = stablehlo.constant dense<[[true, true, true, true, false, false, false, false, false, false], [true, true, true, true, true, false, false, false, false, false], [true, true, true, true, true, true, false, false, false, false], [true, true, true, true, true, true, true, false, false, false], [true, true, true, true, true, true, true, true, false, false], [true, true, true, true, true, true, true, true, true, false], [true, true, true, true, true, true, true, true, true, true], [true, true, true, true, true, true, true, true, true, true], [true, true, true, true, true, true, true, true, true, true], [true, true, true, true, true, true, true, true, true, true]]> : tensor<10x10xi1>
+//   %0 = stablehlo.select %c_0, %arg0, %cst : tensor<10x10xi1>, tensor<10x10xf32>
+//   %1 = stablehlo.select %c, %0, %cst : tensor<10x10xi1>, tensor<10x10xf32>
+//   return %1 : tensor<10x10xf32>
+// }
+
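+// main2 zeroes %arg0 outside the entries with row - 2 <= col <= row + 3, i.e.
+// a band with lower bandwidth 2 and upper bandwidth 3.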
+func.func @main2(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> {
+  %c = stablehlo.constant dense<2> : tensor<10x10xi64>
+  %cst = stablehlo.constant dense<0.000000e+00> : tensor<10x10xf32>
+  %c_0 = stablehlo.constant dense<-3> : tensor<10x10xi64>
+  %0 = stablehlo.iota dim = 1 : tensor<10x10xi64>
+  %1 = stablehlo.iota dim = 0 : tensor<10x10xi64>
+  %2 = stablehlo.subtract %1, %c_0 : tensor<10x10xi64>
+  %3 = stablehlo.compare LE, %0, %2 : (tensor<10x10xi64>, tensor<10x10xi64>) -> tensor<10x10xi1>
+  %4 = stablehlo.subtract %1, %c : tensor<10x10xi64>
+  %5 = stablehlo.compare GE, %0, %4 : (tensor<10x10xi64>, tensor<10x10xi64>) -> tensor<10x10xi1>
+  %6 = stablehlo.select %3, %arg0, %cst : tensor<10x10xi1>, tensor<10x10xf32>
+  %7 = stablehlo.select %5, %6, %cst : tensor<10x10xi1>, tensor<10x10xf32>
+  return %7 : tensor<10x10xf32>
+}
diff --git a/test/lit_tests/structured_tensors/newton_schulz.mlir b/test/lit_tests/structured_tensors/newton_schulz.mlir
new file mode 100644
index 000000000..5d804cc80
--- /dev/null
+++ b/test/lit_tests/structured_tensors/newton_schulz.mlir
@@ -0,0 +1,29 @@
+// RUN: enzymexlamlir-opt --structured-matrix-simplify %s | FileCheck %s
+
+module {
+  func.func @main(%arg0: tensor<5x4xf32>) -> tensor<5x4xf32> {
+    %c = stablehlo.constant dense<0> : tensor<i64>
+    %c_0 = stablehlo.constant dense<5> : tensor<i64>
+    %c_1 = stablehlo.constant dense<1> : tensor<i64>
+    %cst = stablehlo.constant dense<2.031500e+00> : tensor<4x4xf32>
+    %cst_2 = stablehlo.constant dense<-4.775000e+00> : tensor<4x4xf32>
+    %cst_3 = stablehlo.constant dense<3.444500e+00> : tensor<5x4xf32>
+    %0:2 = stablehlo.while(%iterArg = %c, %iterArg_4 = %arg0) : tensor<i64>, tensor<5x4xf32>
+     cond {
+      %1 = stablehlo.compare LT, %iterArg, %c_0 : (tensor<i64>, tensor<i64>) -> tensor<i1>
+      stablehlo.return %1 : tensor<i1>
+    } do {
+      %1 = stablehlo.add %iterArg, %c_1 : tensor<i64>
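+      // %2 = X^T X is symmetric, and scaling it by the splat constants in %3
+      // and %4 preserves that symmetry.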
+      %2 = stablehlo.dot_general %iterArg_4, %iterArg_4, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<4x4xf32>
+      %3 = stablehlo.multiply %cst_2, %2 : tensor<4x4xf32>
+      %4 = stablehlo.multiply %cst, %2 : tensor<4x4xf32>
+      %5 = stablehlo.dot_general %4, %2, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
+      %6 = stablehlo.add %3, %5 : tensor<4x4xf32>
+      %7 = stablehlo.multiply %cst_3, %iterArg_4 : tensor<5x4xf32>
+      %8 = stablehlo.dot_general %iterArg_4, %6, contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] : (tensor<5x4xf32>, tensor<4x4xf32>) -> tensor<5x4xf32>
+      %9 = stablehlo.add %7, %8 : tensor<5x4xf32>
+      stablehlo.return %1, %9 : tensor<i64>, tensor<5x4xf32>
+    }
+    return %0#1 : tensor<5x4xf32>
+  }
+}
diff --git a/test/lit_tests/structured_tensors/symmetric.mlir b/test/lit_tests/structured_tensors/symmetric.mlir
new file mode 100644
index 000000000..6943ecda7
--- /dev/null
+++ b/test/lit_tests/structured_tensors/symmetric.mlir
@@ -0,0 +1,14 @@
+// RUN: enzymexlamlir-opt --structured-matrix-simplify %s | FileCheck %s
+
+module {
+  func.func @symmetric(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
+    %c0 = stablehlo.constant dense<0.000000e+00> : tensor<2x2xf32>
+    %c1 = stablehlo.constant dense<1.000000e+00> : tensor<2x2xf32>
+    %cst = stablehlo.add %c0, %c1 : tensor<2x2xf32>
+    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
+    %1 = stablehlo.add %0, %arg0 : tensor<2x2xf32>
+    %2 = stablehlo.transpose %cst, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
+    %3 = stablehlo.add %1, %2 : tensor<2x2xf32>
+    return %3 : tensor<2x2xf32>
+  }
+}