From 22ffa0192d8a7739c90a3f5f3ccea2d55812c589 Mon Sep 17 00:00:00 2001 From: Thomas Koehler Date: Thu, 18 Sep 2025 14:25:55 +0200 Subject: [PATCH 1/7] starting the reggvolution --- src/main/scala/rise/eqsat/Reggvolution.scala | 147 +++++++++++++++++++ 1 file changed, 147 insertions(+) create mode 100644 src/main/scala/rise/eqsat/Reggvolution.scala diff --git a/src/main/scala/rise/eqsat/Reggvolution.scala b/src/main/scala/rise/eqsat/Reggvolution.scala new file mode 100644 index 000000000..4fa51e2a7 --- /dev/null +++ b/src/main/scala/rise/eqsat/Reggvolution.scala @@ -0,0 +1,147 @@ +package rise.eqsat + +/* + This package contains features to translate Rise expressions and rewrites + to an egg-compatible language living in Rust. + + see https://github.com/Bastacyclop/reggvolution + + */ +object Reggvolution { + def sym(s: String): String = s + // s"""sym("$s")""" + + // cascade of apps bearing no types, used to encode many language constructs + // as simple symbol applications + def noTyApp(f: String, args: Iterable[String]): String = { + args.foldLeft(f) { case (acc, arg) => + s"(app $acc $arg)" + // s"App([$acc, $arg])" + } + } + + type Shift = rise.eqsat.Expr.Shift + + def reggvolve(expr: Expr): String = + Expr.reggvolve(expr, (0, 0, 0, 0, 0)) + + object Expr { + def reggvolve(expr: Expr, s: Shift): String = { + val e = expr.node match { + case Var(index) => s"%${index + s._1}" + // s"Var(${index + s._1})" + case App(f, e) => s"(app ${reggvolve(f, s)} ${reggvolve(e, s)})" + // s"App([${reggvolve(f, s)}, ${reggvolve(e, s)}])" + case NatApp(f, x) => s"(app ${reggvolve(f, s)} ${Nat.reggvolve(x, s)})" + case DataApp(f, x) => s"(app ${reggvolve(f, s)} ${DataType.reggvolve(x, s)})" + case AddrApp(f, x) => s"(app ${reggvolve(f, s)} ${Addr.reggvolve(x, s)})" + case AppNatToNat(f, x) => s"(app ${reggvolve(f, s)} ${NatToNat.reggvolve(x, s)})" + case Lambda(e) => + val s2 = (s._1, s._2 + 1, s._3 + 1, s._4 + 1, s._5 + 1) + s"(lam ${reggvolve(e, s2)})" + case NatLambda(e) => + val s2 = (s._1 + 1, s._2, s._3 + 1, s._4 + 1, s._5 + 1) + s"(lam ${reggvolve(e, s2)})" + case DataLambda(e) => + val s2 = (s._1 + 1, s._2 + 1, s._3, s._4 + 1, s._5 + 1) + s"(lam ${reggvolve(e, s2)})" + case AddrLambda(e) => + val s2 = (s._1 + 1, s._2 + 1, s._3 + 1, s._4, s._5 + 1) + s"(lam ${reggvolve(e, s2)})" + case LambdaNatToNat(e) => + val s2 = (s._1 + 1, s._2 + 1, s._3 + 1, s._4 + 1, s._5) + s"(lam ${reggvolve(e, s)})" + case Literal(d) => + import rise.core.semantics._ + + d match { + case BoolData(true) => sym("true") + case BoolData(false) => sym("false") + case IntData(i) => i.toString() // s"Integer($i)" + case FloatData(f) => f.toString() // s"Float($f)" + case DoubleData(d) => d.toString() // s"Double($d)" + case _=> throw new Exception(s"not supporting literal $d yet") + } + case NatLiteral(n) => Nat.reggvolve(n, s) + case IndexLiteral(i, n) => noTyApp(sym("idxL"), List(i, n).map(Nat.reggvolve(_, s))) + case Primitive(p) => sym(p.name) + case Composition(f, g) => ??? + } + val t = Type.reggvolve(expr.t, s) + // s"TypeOf([$e, $t])" + s"(typeOf $e $t)" + } + } + + object Type { + def reggvolve(ty: Type, s: Shift): String = { + ty.node match { + case dt: DataTypeNode[Nat, DataType] => + DataType.reggvolve(rise.eqsat.DataType(dt), s) + case FunType(a, b) => noTyApp(sym("fun"), List(a, b).map(reggvolve(_, s))) + // TODO: do we need to remember the arg kind as a type ? + case NatFunType(t) => + val s2 = (s._1 + 1, s._2, s._3 + 1, s._4 + 1, s._5 + 1) + s"(lam ${reggvolve(t, s2)})" + case DataFunType(t) => + val s2 = (s._1 + 1, s._2 + 1, s._3, s._4 + 1, s._5 + 1) + s"(lam ${reggvolve(t, s2)})" + case AddrFunType(t) => + val s2 = (s._1 + 1, s._2 + 1, s._3 + 1, s._4, s._5 + 1) + s"(lam ${reggvolve(t, s2)})" + case NatToNatFunType(t) => + val s2 = (s._1 + 1, s._2 + 1, s._3 + 1, s._4 + 1, s._5) + s"(lam ${reggvolve(t, s2)})" + } + } + } + + object Nat { + def reggvolve(n: Nat, s: Shift): String = { + n.node match { + case NatVar(index) => s"%${index + s._2}" + case NatCst(value) => value.toString() + case NatNegInf => ??? + case NatPosInf => ??? + case NatAdd(a, b) => noTyApp(sym("add"), List(a, b).map(reggvolve(_, s))) + case NatMul(a, b) => noTyApp(sym("mul"), List(a, b).map(reggvolve(_, s))) + case NatPow(a, b) => noTyApp(sym("pow"), List(a, b).map(reggvolve(_, s))) + case NatMod(a, b) => noTyApp(sym("mod"), List(a, b).map(reggvolve(_, s))) + case NatIntDiv(a, b) => noTyApp(sym("floorDiv"), List(a, b).map(reggvolve(_, s))) + case NatToNatApp(f, n) => ??? + } + } + } + + object DataType { + def reggvolve(dty: DataType, s: Shift): String = { + dty.node match { + case DataTypeVar(index) => s"%${index + s._3}" + case ScalarType(s) => sym(s.toString()) + case NatType => sym("natT") + case IndexType(n) => noTyApp(sym("idxT"), List(Nat.reggvolve(n, s))) + case PairType(dt1, dt2) => noTyApp(sym("pairT"), List(dt1, dt2).map(reggvolve(_, s))) + case ArrayType(n, et) => noTyApp(sym("arrT"), List(Nat.reggvolve(n, s), reggvolve(et, s))) + case VectorType(n, et) => noTyApp(sym("vecT"), List(Nat.reggvolve(n, s), reggvolve(et, s))) + } + } + } + + object Addr { + def reggvolve(a: Address, s: Shift): String = { + a match { + case AddressVar(index) => s"%${index + s._4}" + case Global => sym("global") + case Local => sym("local") + case Private => sym("private") + case Constant => sym("constant") + } + } + } + + object NatToNat { + def reggvolve(n: NatToNat, s: Shift): String = { + ??? + } + } +} From 11eef6a66a326def53dc032fb178a6171be6db57 Mon Sep 17 00:00:00 2001 From: Thomas Koehler Date: Thu, 18 Sep 2025 17:32:41 +0200 Subject: [PATCH 2/7] support patterns and simple rules --- src/main/scala/benchmarks/eqsat/mm.scala | 23 +++ src/main/scala/rise/eqsat/Pattern.scala | 167 ++++++++++--------- src/main/scala/rise/eqsat/Reggvolution.scala | 104 +++++++----- 3 files changed, 173 insertions(+), 121 deletions(-) diff --git a/src/main/scala/benchmarks/eqsat/mm.scala b/src/main/scala/benchmarks/eqsat/mm.scala index 309e70121..a858aaeed 100644 --- a/src/main/scala/benchmarks/eqsat/mm.scala +++ b/src/main/scala/benchmarks/eqsat/mm.scala @@ -261,6 +261,29 @@ object mm { private def baseline(): GuidedSearch.Result = { val start = mm + + println(Reggvolution.reggvolve(Expr.fromNamed(mm))) + println("---- RULES ----") + var success = 0 + var fail = 0 + def tryReggvolveRule(r: Rewrite) = { + try { + println(Reggvolution.reggvolve(r)) + success = success + 1 + } catch { + case e: Exception => + fail = fail + 1 + println(s"could not reggvolve rule ${r.name}: ${e}") + } + } + splitStepBENF.rules.foreach(tryReggvolveRule) + reorderStepBENF.rules.foreach(tryReggvolveRule) + copyStep.rules.foreach(tryReggvolveRule) + loweringStep.rules.foreach(tryReggvolveRule) + println("----") + println(s"translated: ${success}/${success+fail}") + + throw new Exception("done") val steps = Seq( emptyStep withRules Seq(rules.reduceSeq, rules.reduceSeqMapFusion) diff --git a/src/main/scala/rise/eqsat/Pattern.scala b/src/main/scala/rise/eqsat/Pattern.scala index e5f5db05f..f06b3aea7 100644 --- a/src/main/scala/rise/eqsat/Pattern.scala +++ b/src/main/scala/rise/eqsat/Pattern.scala @@ -13,60 +13,8 @@ object Pattern { Pattern(PatternNode(pnode), TypePattern.fromType(e.t)) } - implicit def patternToApplier(pattern: Pattern): Applier = new Applier { - override def toString: String = pattern.toString - - override def patternVars(): Set[Any] = pattern.patternVars() - - override def requiredAnalyses(): (Set[Analysis], Set[TypeAnalysis]) = - (Set(), Set()) - - override def applyOne(egraph: EGraph, - eclass: EClassId, - shc: Substs)( - subst: shc.Subst): Vec[EClassId] = { - def missingRhsTy[T](): T = throw new Exception("unknown type on right-hand side") - def pat(p: Pattern): EClassId = { - p.p match { - case w: PatternVar => shc.get(w, subst) - case PatternNode(n) => - val enode = n.map(pat, nat, data, addr) - egraph.add(enode, `type`(p.t)) - } - } - def nat(p: NatPattern): NatId = { - p match { - case w: NatPatternVar => shc.get(w, subst) - case NatPatternNode(n) => egraph.add(n.map(nat)) - case NatPatternAny => missingRhsTy() - } - } - def data(pat: DataTypePattern): DataTypeId = { - pat match { - case w: DataTypePatternVar => shc.get(w, subst) - case DataTypePatternNode(n) => egraph.add(n.map(nat, data)) - case DataTypePatternAny => missingRhsTy() - } - } - def `type`(pat: TypePattern): TypeId = { - pat match { - case w: TypePatternVar => shc.get(w, subst) - case TypePatternNode(n) => egraph.add(n.map(`type`, nat, data)) - case TypePatternAny => missingRhsTy() - case dtp: DataTypePattern => data(dtp) - } - } - def addr(pat: AddressPattern): Address = { - pat match { - case w: AddressPatternVar => shc.get(w, subst) - case AddressPatternNode(n) => n - case AddressPatternAny => missingRhsTy() - } - } - - Vec(pat(pattern)) - } - } + implicit def patternToApplier(pattern: Pattern): Applier = + PatternApplier(pattern) } sealed trait PatternVarOrNode @@ -106,37 +54,94 @@ case class CompiledPattern(pat: Pattern, prog: ematching.Program) { object CompiledPattern { implicit def patternToSearcher(cpat: CompiledPattern) - : Searcher = new Searcher { - override def toString: String = cpat.toString - - override def patternVars(): Set[Any] = cpat.pat.patternVars() - - override def search(egraph: EGraph, - shc: Substs, - ): Vec[SearchMatches[shc.Subst]] = { - cpat.pat.p match { - case PatternNode(node) => - egraph.classesByMatch.get(node.matchHash()) match { - case None => Vec.empty - case Some(ids) => - ids.iterator.flatMap(id => searchEClass(egraph, shc, id)).to(Vec) - } - case PatternVar(_) => egraph.classes.keysIterator - .flatMap(id => searchEClass(egraph, shc, id)).to(Vec) - } - } + : Searcher = CompiledPatternSearcher(cpat.pat, cpat.prog) + + implicit def patternToApplier(cpat: CompiledPattern): Applier = + Pattern.patternToApplier(cpat.pat) +} + +case class CompiledPatternSearcher(pat: Pattern, prog: ematching.Program) extends Searcher { + override def toString: String = s"CompiledPattern(${pat.toString()})" - override def searchEClass(egraph: EGraph, - shc: Substs, - eclass: EClassId, - ): Option[SearchMatches[shc.Subst]] = { - val substs = cpat.prog.run(egraph, eclass, shc) - if (substs.isEmpty) { None } else { Some(SearchMatches(eclass, substs)) } + override def patternVars(): Set[Any] = pat.patternVars() + + override def search(egraph: EGraph, + shc: Substs, + ): Vec[SearchMatches[shc.Subst]] = { + pat.p match { + case PatternNode(node) => + egraph.classesByMatch.get(node.matchHash()) match { + case None => Vec.empty + case Some(ids) => + ids.iterator.flatMap(id => searchEClass(egraph, shc, id)).to(Vec) + } + case PatternVar(_) => egraph.classes.keysIterator + .flatMap(id => searchEClass(egraph, shc, id)).to(Vec) } } - implicit def patternToApplier(cpat: CompiledPattern): Applier = - Pattern.patternToApplier(cpat.pat) + override def searchEClass(egraph: EGraph, + shc: Substs, + eclass: EClassId, + ): Option[SearchMatches[shc.Subst]] = { + val substs = prog.run(egraph, eclass, shc) + if (substs.isEmpty) { None } else { Some(SearchMatches(eclass, substs)) } + } +} + +case class PatternApplier(pattern: Pattern) extends Applier { + override def toString: String = pattern.toString + + override def patternVars(): Set[Any] = pattern.patternVars() + + override def requiredAnalyses(): (Set[Analysis], Set[TypeAnalysis]) = + (Set(), Set()) + + override def applyOne(egraph: EGraph, + eclass: EClassId, + shc: Substs)( + subst: shc.Subst): Vec[EClassId] = { + def missingRhsTy[T](): T = throw new Exception("unknown type on right-hand side") + def pat(p: Pattern): EClassId = { + p.p match { + case w: PatternVar => shc.get(w, subst) + case PatternNode(n) => + val enode = n.map(pat, nat, data, addr) + egraph.add(enode, `type`(p.t)) + } + } + def nat(p: NatPattern): NatId = { + p match { + case w: NatPatternVar => shc.get(w, subst) + case NatPatternNode(n) => egraph.add(n.map(nat)) + case NatPatternAny => missingRhsTy() + } + } + def data(pat: DataTypePattern): DataTypeId = { + pat match { + case w: DataTypePatternVar => shc.get(w, subst) + case DataTypePatternNode(n) => egraph.add(n.map(nat, data)) + case DataTypePatternAny => missingRhsTy() + } + } + def `type`(pat: TypePattern): TypeId = { + pat match { + case w: TypePatternVar => shc.get(w, subst) + case TypePatternNode(n) => egraph.add(n.map(`type`, nat, data)) + case TypePatternAny => missingRhsTy() + case dtp: DataTypePattern => data(dtp) + } + } + def addr(pat: AddressPattern): Address = { + pat match { + case w: AddressPatternVar => shc.get(w, subst) + case AddressPatternNode(n) => n + case AddressPatternAny => missingRhsTy() + } + } + + Vec(pat(pattern)) + } } object PatternDSL { diff --git a/src/main/scala/rise/eqsat/Reggvolution.scala b/src/main/scala/rise/eqsat/Reggvolution.scala index 4fa51e2a7..9a1a5502a 100644 --- a/src/main/scala/rise/eqsat/Reggvolution.scala +++ b/src/main/scala/rise/eqsat/Reggvolution.scala @@ -23,19 +23,36 @@ object Reggvolution { type Shift = rise.eqsat.Expr.Shift def reggvolve(expr: Expr): String = - Expr.reggvolve(expr, (0, 0, 0, 0, 0)) + // NOTE: could define reggvolution for generic nodes, but this is simpler + reggvolve(Pattern.fromExpr(expr)) - object Expr { - def reggvolve(expr: Expr, s: Shift): String = { - val e = expr.node match { + def reggvolve(rw: Rewrite): String = { + val lhs = rw.searcher match { + case cp: CompiledPatternSearcher => reggvolve(cp.pat) + case _ => throw new Exception(s"could not reggvolve searcher: ${rw.searcher.getClass()}") + } + val rhs = rw.applier match { + case cp: PatternApplier => reggvolve(cp.pattern) + case _ => throw new Exception(s"could not reggvolve applier: ${rw.applier.getClass()}") + } + s"""rewrite!("${rw.name}", "${lhs}" => "${rhs}")""" + } + + def reggvolve(pat: Pattern): String = + reggvolve(pat, (0, 0, 0, 0, 0)) + + def reggvolve(pat: Pattern, s: Shift): String = { + val e = pat.p match { + case PatternVar(index) => s"?e${index}" + case PatternNode(node) => node match { case Var(index) => s"%${index + s._1}" // s"Var(${index + s._1})" case App(f, e) => s"(app ${reggvolve(f, s)} ${reggvolve(e, s)})" // s"App([${reggvolve(f, s)}, ${reggvolve(e, s)}])" - case NatApp(f, x) => s"(app ${reggvolve(f, s)} ${Nat.reggvolve(x, s)})" - case DataApp(f, x) => s"(app ${reggvolve(f, s)} ${DataType.reggvolve(x, s)})" - case AddrApp(f, x) => s"(app ${reggvolve(f, s)} ${Addr.reggvolve(x, s)})" - case AppNatToNat(f, x) => s"(app ${reggvolve(f, s)} ${NatToNat.reggvolve(x, s)})" + case NatApp(f, x) => s"(app ${reggvolve(f, s)} ${reggvolve(x, s)})" + case DataApp(f, x) => s"(app ${reggvolve(f, s)} ${reggvolve(x, s)})" + case AddrApp(f, x) => s"(app ${reggvolve(f, s)} ${reggvolve(x, s)})" + case AppNatToNat(f, x) => s"(app ${reggvolve(f, s)} ${reggvolve(x, s)})" case Lambda(e) => val s2 = (s._1, s._2 + 1, s._3 + 1, s._4 + 1, s._5 + 1) s"(lam ${reggvolve(e, s2)})" @@ -62,22 +79,26 @@ object Reggvolution { case DoubleData(d) => d.toString() // s"Double($d)" case _=> throw new Exception(s"not supporting literal $d yet") } - case NatLiteral(n) => Nat.reggvolve(n, s) - case IndexLiteral(i, n) => noTyApp(sym("idxL"), List(i, n).map(Nat.reggvolve(_, s))) + case NatLiteral(n) => reggvolve(n, s) + case IndexLiteral(i, n) => noTyApp(sym("idxL"), List(i, n).map(reggvolve(_, s))) case Primitive(p) => sym(p.name) case Composition(f, g) => ??? } - val t = Type.reggvolve(expr.t, s) - // s"TypeOf([$e, $t])" - s"(typeOf $e $t)" } + val t = reggvolve(pat.t, s) + // s"TypeOf([$e, $t])" + s"(typeOf $e $t)" } - object Type { - def reggvolve(ty: Type, s: Shift): String = { - ty.node match { - case dt: DataTypeNode[Nat, DataType] => - DataType.reggvolve(rise.eqsat.DataType(dt), s) + def reggvolve(ty: TypePattern, s: Shift): String = { + ty match { + case TypePatternVar(index) => s"?t${index}" + case DataTypePatternVar(index) => s"?dt${index}" + case TypePatternAny => "?" + case DataTypePatternAny => "?" + case TypePatternNode(n) => n match { + case dt: DataTypeNode[_, _] => + reggvolve(rise.eqsat.DataTypePatternNode(dt), s) case FunType(a, b) => noTyApp(sym("fun"), List(a, b).map(reggvolve(_, s))) // TODO: do we need to remember the arg kind as a type ? case NatFunType(t) => @@ -93,12 +114,16 @@ object Reggvolution { val s2 = (s._1 + 1, s._2 + 1, s._3 + 1, s._4 + 1, s._5) s"(lam ${reggvolve(t, s2)})" } + // FIXME: this construct is redundant ??? + case dtn: DataTypePatternNode => reggvolve(dtn, s) } } - - object Nat { - def reggvolve(n: Nat, s: Shift): String = { - n.node match { + + def reggvolve(n: NatPattern, s: Shift): String = { + n match { + case NatPatternVar(index) => s"?n${index}" + case NatPatternAny => "?" + case NatPatternNode(n) => n match { case NatVar(index) => s"%${index + s._2}" case NatCst(value) => value.toString() case NatNegInf => ??? @@ -113,23 +138,23 @@ object Reggvolution { } } - object DataType { - def reggvolve(dty: DataType, s: Shift): String = { - dty.node match { - case DataTypeVar(index) => s"%${index + s._3}" - case ScalarType(s) => sym(s.toString()) - case NatType => sym("natT") - case IndexType(n) => noTyApp(sym("idxT"), List(Nat.reggvolve(n, s))) - case PairType(dt1, dt2) => noTyApp(sym("pairT"), List(dt1, dt2).map(reggvolve(_, s))) - case ArrayType(n, et) => noTyApp(sym("arrT"), List(Nat.reggvolve(n, s), reggvolve(et, s))) - case VectorType(n, et) => noTyApp(sym("vecT"), List(Nat.reggvolve(n, s), reggvolve(et, s))) - } + def reggvolve(dty: DataTypePatternNode, s: Shift): String = { + dty.n match { + case DataTypeVar(index) => s"%${index + s._3}" + case ScalarType(s) => sym(s.toString()) + case NatType => sym("natT") + case IndexType(n) => noTyApp(sym("idxT"), List(reggvolve(n, s))) + case PairType(dt1, dt2) => noTyApp(sym("pairT"), List(dt1, dt2).map(reggvolve(_, s))) + case ArrayType(n, et) => noTyApp(sym("arrT"), List(reggvolve(n, s), reggvolve(et, s))) + case VectorType(n, et) => noTyApp(sym("vecT"), List(reggvolve(n, s), reggvolve(et, s))) } } - object Addr { - def reggvolve(a: Address, s: Shift): String = { - a match { + def reggvolve(a: AddressPattern, s: Shift): String = { + a match { + case AddressPatternVar(index) => s"?a${index}" + case AddressPatternAny => "?" + case AddressPatternNode(n) => n match { case AddressVar(index) => s"%${index + s._4}" case Global => sym("global") case Local => sym("local") @@ -139,9 +164,8 @@ object Reggvolution { } } - object NatToNat { - def reggvolve(n: NatToNat, s: Shift): String = { - ??? - } + def reggvolve(n: NatToNatNode[NatPattern], s: Shift): String = { + ??? } + } From 35f5e8ffc262196b71202486e47bb83fe58055b2 Mon Sep 17 00:00:00 2001 From: Thomas Koehler Date: Thu, 18 Sep 2025 17:39:05 +0200 Subject: [PATCH 3/7] typo --- src/main/scala/rise/eqsat/Reggvolution.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/scala/rise/eqsat/Reggvolution.scala b/src/main/scala/rise/eqsat/Reggvolution.scala index 9a1a5502a..392485c0b 100644 --- a/src/main/scala/rise/eqsat/Reggvolution.scala +++ b/src/main/scala/rise/eqsat/Reggvolution.scala @@ -35,7 +35,7 @@ object Reggvolution { case cp: PatternApplier => reggvolve(cp.pattern) case _ => throw new Exception(s"could not reggvolve applier: ${rw.applier.getClass()}") } - s"""rewrite!("${rw.name}", "${lhs}" => "${rhs}")""" + s"""rewrite!("${rw.name}"; "${lhs}" => "${rhs}")""" } def reggvolve(pat: Pattern): String = From 4d8867929fdb4896fd0d297bfa7a9907078f5672 Mon Sep 17 00:00:00 2001 From: Thomas Koehler Date: Thu, 18 Sep 2025 17:45:11 +0200 Subject: [PATCH 4/7] remove duplicates --- src/main/scala/benchmarks/eqsat/mm.scala | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/main/scala/benchmarks/eqsat/mm.scala b/src/main/scala/benchmarks/eqsat/mm.scala index a858aaeed..1e190b962 100644 --- a/src/main/scala/benchmarks/eqsat/mm.scala +++ b/src/main/scala/benchmarks/eqsat/mm.scala @@ -264,16 +264,20 @@ object mm { println(Reggvolution.reggvolve(Expr.fromNamed(mm))) println("---- RULES ----") + var visited = Set[Rewrite]() var success = 0 var fail = 0 def tryReggvolveRule(r: Rewrite) = { - try { - println(Reggvolution.reggvolve(r)) - success = success + 1 - } catch { - case e: Exception => - fail = fail + 1 - println(s"could not reggvolve rule ${r.name}: ${e}") + if (!visited.contains(r)) { + visited = visited + r + try { + println(Reggvolution.reggvolve(r)) + success = success + 1 + } catch { + case e: Exception => + fail = fail + 1 + println(s"could not reggvolve rule ${r.name}: ${e}") + } } } splitStepBENF.rules.foreach(tryReggvolveRule) From 059e8db77d493f48cfd8c9fa38213f9c462df787 Mon Sep 17 00:00:00 2001 From: Thomas Koehler Date: Tue, 30 Sep 2025 15:08:52 +0200 Subject: [PATCH 5/7] reggvovle named rewrite (support appliers) --- src/main/scala/benchmarks/eqsat/mm.scala | 201 +++++++-- src/main/scala/rise/eqsat/NamedRewrite.scala | 83 ++-- src/main/scala/rise/eqsat/Reggvolution.scala | 412 ++++++++++++++++++- src/main/scala/rise/eqsat/Rewrite.scala | 170 ++++---- 4 files changed, 713 insertions(+), 153 deletions(-) diff --git a/src/main/scala/benchmarks/eqsat/mm.scala b/src/main/scala/benchmarks/eqsat/mm.scala index 1e190b962..e17f11250 100644 --- a/src/main/scala/benchmarks/eqsat/mm.scala +++ b/src/main/scala/benchmarks/eqsat/mm.scala @@ -261,33 +261,6 @@ object mm { private def baseline(): GuidedSearch.Result = { val start = mm - - println(Reggvolution.reggvolve(Expr.fromNamed(mm))) - println("---- RULES ----") - var visited = Set[Rewrite]() - var success = 0 - var fail = 0 - def tryReggvolveRule(r: Rewrite) = { - if (!visited.contains(r)) { - visited = visited + r - try { - println(Reggvolution.reggvolve(r)) - success = success + 1 - } catch { - case e: Exception => - fail = fail + 1 - println(s"could not reggvolve rule ${r.name}: ${e}") - } - } - } - splitStepBENF.rules.foreach(tryReggvolveRule) - reorderStepBENF.rules.foreach(tryReggvolveRule) - copyStep.rules.foreach(tryReggvolveRule) - loweringStep.rules.foreach(tryReggvolveRule) - println("----") - println(s"translated: ${success}/${success+fail}") - - throw new Exception("done") val steps = Seq( emptyStep withRules Seq(rules.reduceSeq, rules.reduceSeqMapFusion) @@ -648,6 +621,7 @@ object mm { } def main(args: Array[String]): Unit = { + Reggvolve.init_rules() // val names = Set(args(0)) // fs.filter { case (k, _) => names(k) } @@ -690,3 +664,176 @@ object mm { } } } + +object Reggvolve { + def init_rules(): Unit = { + import rise.core.{primitives => rcp} + import rise.core.types.{Nat, DataType, AddressSpace} + import NamedRewriteDSL._ + + println(Reggvolution.reggvolve(Expr.fromNamed(benchmarks.eqsat.mm.mm))) + + println("---- RULES ----") + + println(Reggvolution.reggvolveNamedRewrite("map-fission", + app(map, lam("x", app("f", "gx" :: ("dt": DataType)))) + --> + lam("in", app(app(map, "f"), app(app(map, lam("x", "gx")), "in"))), + Seq("f" notFree "x") + )) + println(Reggvolution.reggvolveNamedRewrite("reduce-seq", + reduce --> rcp.reduceSeq.primitive + )) + println(Reggvolution.reggvolveNamedRewrite("eliminate-map-identity", + app(map, lam("x", "x")) + --> + lam("y", "y") + )) + println(Reggvolution.reggvolveNamedRewrite("reduce-seq-map-fusion", + app(app(app(rcp.reduceSeq.primitive, "f"), "init"), app(app(map, "g"), "in")) + --> + app(app(app(rcp.reduceSeq.primitive, lam("acc", lam("x", + app(app("f", "acc"), app("g", "x"))))), "init"), "in") + )) + def splitJoin(n: Int) = Reggvolution.reggvolveNamedRewrite(s"split-join-$n", + app(app(map, "f"), "in") + --> + app(join, app(app(map, app(map, "f")), app(nApp(split, n), "in"))) + ) + println(splitJoin(32)) + def splitJoin2M(n: Int) = Reggvolution.reggvolveNamedRewrite(s"split-join-2m-$n", + app(app(map, app(map, app(map, "f"))), "in") + --> + app(app(map, app(map, join)), app(app(map, app(map, app(map, app(map, "f")))), app(app(map, app(map, nApp(split, n))), "in"))) + ) + println(splitJoin2M(32)) + def blockedReduce(n: Int) = Reggvolution.reggvolveNamedRewrite(s"blocked-reduce-$n", + app(app(app(reduce, "op" :: ("a" ->: "a" ->: t("a"))), "init"), "arg") + --> + app(app(app(rcp.reduceSeq.primitive, + lam("acc", lam("y", app(app("op", "acc"), + app(app(app(reduce, "op"), "init"), "y"))))), + "init"), app(nApp(split, n), "arg")) + ) + println(blockedReduce(4)) + println(Reggvolution.reggvolveNamedRewrite("split-before-map", + app(nApp(split, "n"), app(app(map, "f"), "in")) + --> + app(app(map, app(map, "f")), app(nApp(split, "n"), "in")) + )) + println(Reggvolution.reggvolveNamedRewrite("reduce-seq-map-fission", + app(app(rcp.reduceSeq.primitive, lam("acc", lam("y", + app(app("op", "acc"), "gy" :: ("dt": DataType))))), "init") + --> + lam("in", app(app(app(rcp.reduceSeq.primitive, "op"), "init"), + app(app(map, lam("y", "gy")), "in"))), + Seq("op" notFree "y") + )) + println(Reggvolution.reggvolveNamedRewrite("lift-reduce-seq", + app(map, app(app(rcp.reduceSeq.primitive, "op"), "init")) + --> + lam("in", + app(app(app(rcp.reduceSeq.primitive, lam("acc", lam("y", + app(app(map, lam("z", app(app("op", app(fst, "z")), app(snd, "z")))), + app(app(zip, "acc"), "y")) + ))), app(rcp.generate.primitive, lam("i", "init"))), + app(transpose, "in"))) + )) + println(Reggvolution.reggvolveNamedRewrite("lift-reduce-seq-2", + app(map, lam("x", app(app(add, app(fst, "x")), + app(app(app(rcp.reduceSeq.primitive, "op"), lf32(0)), app(snd, "x"))))) + --> + lam("in", app(lam("uz", + app(app(app(rcp.reduceSeq.primitive, lam("acc", lam("y", + app(app(map, lam("z", app(app("op", app(fst, "z")), app(snd, "z")))), + app(app(zip, "acc"), "y")) + ))), app(fst, "uz")), app(transpose, app(snd, "uz")))), + app(unzip, "in"))) + )) + println(Reggvolution.reggvolveNamedRewrite("lift-reduce-seq-3", + (app(map, lam("x", + app(app(app(rcp.reduceSeq.primitive, "op"), app(fst, app(unzip, "x"))), + app(transpose, app(snd, app(unzip, "x"))))))) + --> + lam("in", + app(app(app(rcp.reduceSeq.primitive, lam("acc", lam("y", + app(app(map, lam("z", app(app("op", app(fst, "z")), app(snd, "z")))), + app(app(zip, "acc"), "y")) + ))), app(fst, app(unzip, app(app(map, unzip), "in")))), + app(transpose, app(app(map, transpose), app(snd, app(unzip, app(app(map, unzip), "in"))))))), + Seq("op" notFree "x") + )) + println(Reggvolution.reggvolveNamedRewrite("transpose-around-map-map-f-1m", + app(app(map, app(map, app(map, "f"))), "in") + --> + app(app(map, transpose), app(app(map, app(map, app(map, "f"))), app(app(map, transpose), "in"))) + )) + println(Reggvolution.reggvolveNamedRewrite("store-to-mem", + ("in" :: ("dt": DataType)) + --> + app(app(rcp.let.primitive, app(rcp.toMem.primitive, "in")), lam("x", "x")) + )) + def splitJoin2(n: Int) = Reggvolution.reggvolveNamedRewrite(s"split-join-2-$n", + ("in" :: (`?n``.``?dt`)) + --> + app(join, app(nApp(split, n), "in")) + ) + splitJoin2(32) + println(Reggvolution.reggvolveNamedRewrite("map-array", + ("x" :: (`?n``.``?dt`)) + --> + app(app(map, lam("y", "y")), "x") + )) + println(Reggvolution.reggvolveNamedRewrite("map-fusion", + app(app(map, "f"), app(app(map, "g"), "in")) + --> + app(app(map, lam("x", app("f", app("g", "x")))), "in") + )) + println(Reggvolution.reggvolveNamedRewrite("map-eta-abstraction", + app(map, "f") --> lam("x", app(app(map, "f"), "x")) + )) + object vectorize { + def after(n: Int, dt: DataType) = Reggvolution.reggvolveNamedRewrite(s"vec-$n-after-$dt", + // TODO: if m % n == 0 ? + ("e" :: (("m": Nat)`.`dt)) + --> + app(asScalar, app(nApp(asVector, n), "e")) + ) + } + println(vectorize.after(32, f32)) + println(vectorize.after(32, f32 x f32)) + println(vectorize.after(32, f32 x (f32 x f32))) + println(Reggvolution.reggvolveNamedRewrite("vec-before-map-f32", + app(nApp(asVector, "n"), app(app(map, "f" :: f32 ->: f32), ("in": Pattern))) + --> + app(app(map, "fV"), app(nApp(asVector, "n"), "in")), + Seq(vectorizeScalarFun("f", "n", "fV")) + )) + println(Reggvolution.reggvolveNamedRewrite("vec-before-map-f32xf32", + app(nApp(asVector, "n"), app(app(map, "f" :: (f32 x f32) ->: f32), ("in": Pattern))) + --> + app(app(map, "fV"), + app(app(zip, app(nApp(asVector, "n"), app(fst, app(unzip, "in")))), + app(nApp(asVector, "n"), app(snd, app(unzip, "in"))))), + Seq(vectorizeScalarFun("f", "n", "fV")) + )) + println(Reggvolution.reggvolveNamedRewrite("vec-before-map-f32x-f32xf32", + app(nApp(asVector, "n"), app(app(map, "f" :: (f32 x (f32 x f32)) ->: f32), ("in": Pattern))) + --> + app(app(map, "fV"), + app(app(zip, app(nApp(asVector, "n"), app(fst, app(unzip, "in")))), + app(app(zip, app(nApp(asVector, "n"), app(fst, app(unzip, app(snd, app(unzip, "in")))))), + app(nApp(asVector, "n"), app(snd, app(unzip, app(snd, app(unzip, "in")))))))), + Seq(vectorizeScalarFun("f", "n", "fV")) + )) + println(Reggvolution.reggvolveNamedRewrite("reduce-seq-unroll", + rcp.reduceSeq.primitive --> rcp.reduceSeqUnroll.primitive + )) + println(Reggvolution.reggvolveNamedRewrite("map-par", + map --> rise.openMP.primitives.mapPar.primitive + )) + + println("----") + throw new Exception("done") + } +} diff --git a/src/main/scala/rise/eqsat/NamedRewrite.scala b/src/main/scala/rise/eqsat/NamedRewrite.scala index e1bcf307d..27245f918 100644 --- a/src/main/scala/rise/eqsat/NamedRewrite.scala +++ b/src/main/scala/rise/eqsat/NamedRewrite.scala @@ -37,12 +37,11 @@ object NamedRewrite { } } - def init(name: String, - rule: (NamedRewriteDSL.Pattern, NamedRewriteDSL.Pattern), - parameters: Seq[NamedRewrite.Parameter] = Seq(), - ): Rewrite = { + def typeRule( + rule: (NamedRewriteDSL.Pattern, NamedRewriteDSL.Pattern), + parameters: Seq[NamedRewrite.Parameter] + ): (rc.Expr, Map[String, rct.ExprType], Set[rct.Kind.Identifier], rc.Expr) = { import rise.core.DSL.infer - import arithexpr.{arithmetic => ae} val (lhs, rhs) = rule val untypedFreeV = infer.collectFreeEnv(lhs).map { case (name, t) => @@ -61,15 +60,48 @@ object NamedRewrite { } val freeV = freeV1 ++ freeV2 val typedRhs = infer(rc.TypeAnnotation(rhs, typedLhs.t), freeV, freeT) + (typedLhs, freeV, freeT, typedRhs) + } + + trait PatVarStatus + case object Unknown extends PatVarStatus + case object Known extends PatVarStatus + // both known and coherent with other shifts + case object ShiftCoherent extends PatVarStatus + + // from var name to var index and a status depending on local index shift + type PatternVarMap[S, V] = HashMap[String, HashMap[S, (V, PatVarStatus)]] + + def makePatVar[S, V]( + name: String, + shift: S, + pvm: PatternVarMap[S, V], + constructor: Int => V, + status: PatVarStatus + ): V = { + val shiftMap = pvm.getOrElseUpdate(name, HashMap()) + val (pv, previousStatus) = shiftMap.getOrElseUpdate(shift, { + val pvCount = pvm.values.map(m => m.size).sum + (constructor(pvCount), Unknown) + }) + val updatedStatus = (previousStatus, status) match { + case (Unknown, s) => s + case (s, Unknown) => s + case (Known, Known) => Known + case t => throw new Exception(s"did not expect $t") + } + shiftMap(shift) = (pv, updatedStatus) + pv + } + + def init(name: String, + rule: (NamedRewriteDSL.Pattern, NamedRewriteDSL.Pattern), + parameters: Seq[NamedRewrite.Parameter] = Seq(), + ): Rewrite = { + import arithexpr.{arithmetic => ae} - trait PatVarStatus - case object Unknown extends PatVarStatus - case object Known extends PatVarStatus - // both known and coherent with other shifts - case object ShiftCoherent extends PatVarStatus + val (typedLhs, freeV, freeT, typedRhs) = typeRule(rule, parameters) - // from var name to var index and a status depending on local index shift - type PatternVarMap[S, V] = HashMap[String, HashMap[S, (V, PatVarStatus)]] val patVars: PatternVarMap[Expr.Shift, PatternVar] = HashMap() val natPatVars: PatternVarMap[Nat.Shift, NatPatternVar] = HashMap() val dataTypePatVars: PatternVarMap[Type.Shift, DataTypePatternVar] = HashMap() @@ -81,26 +113,6 @@ object NamedRewrite { val boundVarToShift = HashMap[String, Expr.Shift]() - def makePatVar[S, V](name: String, - shift: S, - pvm: PatternVarMap[S, V], - constructor: Int => V, - status: PatVarStatus): V = { - val shiftMap = pvm.getOrElseUpdate(name, HashMap()) - val (pv, previousStatus) = shiftMap.getOrElseUpdate(shift, { - val pvCount = pvm.values.map(m => m.size).sum - (constructor(pvCount), Unknown) - }) - val updatedStatus = (previousStatus, status) match { - case (Unknown, s) => s - case (s, Unknown) => s - case (Known, Known) => Known - case t => throw new Exception(s"did not expect $t") - } - shiftMap(shift) = (pv, updatedStatus) - pv - } - def makePat(expr: rc.Expr, bound: Expr.Bound, isRhs: Boolean, @@ -502,12 +514,7 @@ object NamedRewrite { case ((s, _, _, _, _), (pv, Known)) => (s, pv) }.get val nfIndex = iS - nfShift // >= 0 because iS >= nfShift - (a: Applier) => (new ConditionalApplier(Set(iPV), (Set(FreeAnalysis), Set()), acc(a)) { - def cond(egraph: EGraph, eclass: EClassId, shc: Substs)(subst: shc.Subst): Boolean = { - val freeOf = egraph.getAnalysis(FreeAnalysis) - !freeOf(shc.get(iPV, subst)).free.contains(nfIndex) - } - }) + (a: Applier) => NotFreeInApplier(iPV, nfIndex, acc(a)) case VectorizeScalarFun(f, n, fV) => val (nPV, nST) = natPatVars(n)(0, 0) assert(nST == Known) diff --git a/src/main/scala/rise/eqsat/Reggvolution.scala b/src/main/scala/rise/eqsat/Reggvolution.scala index 392485c0b..d8707b8b4 100644 --- a/src/main/scala/rise/eqsat/Reggvolution.scala +++ b/src/main/scala/rise/eqsat/Reggvolution.scala @@ -26,21 +26,417 @@ object Reggvolution { // NOTE: could define reggvolution for generic nodes, but this is simpler reggvolve(Pattern.fromExpr(expr)) - def reggvolve(rw: Rewrite): String = { - val lhs = rw.searcher match { - case cp: CompiledPatternSearcher => reggvolve(cp.pat) - case _ => throw new Exception(s"could not reggvolve searcher: ${rw.searcher.getClass()}") + // same as NamedRewrite.init, but flattens DeBruijn indices from different kinds. + def reggvolveNamedRewrite( + name: String, + rule: (NamedRewriteDSL.Pattern, NamedRewriteDSL.Pattern), + parameters: Seq[NamedRewrite.Parameter] = Seq(), + ): String = { + import rise.core.DSL.infer + import arithexpr.{arithmetic => ae} + import rise.eqsat.NamedRewrite._ + import rise.core.types.DataKind.IDWrapper + import rise.{core => rc} + import rise.core.{types => rct} + import rise.core.types.{DataType => rcdt} + + val (typedLhs, freeV, freeT, typedRhs) = typeRule(rule, parameters) + + type FlatShift = Int + val patVars: PatternVarMap[FlatShift, PatternVar] = HashMap() + + // nats which we need to pivot to avoid matching over certain nat constructs + val natsToPivot = Vec[(rct.Nat, rct.NatIdentifier, FlatShift, NatPatternVar)]() + + val boundVarToShift = HashMap[String, FlatShift]() + + def makeExprPatVar( + name: String, + shift: FlatShift, + status: PatVarStatus + ): PatternVar = + makePatVar(name, shift, patVars, PatternVar, status) + + def makeOtherPatVar[V]( + name: String, + shift: FlatShift, + constructor: Int => V, + status: PatVarStatus, + ): V = + makeExprPatVar(name, shift, status) match { + case PatternVar(index) => constructor(index) + } + + def shiftOfBound(bound: Expr.Bound): FlatShift = + // (bound.expr.size, bound.nat.size, bound.data.size, bound.addr.size, bound.n2n.size) + bound.expr.size + bound.nat.size + bound.data.size + bound.addr.size + bound.n2n.size + + def makePat(expr: rc.Expr, + bound: Expr.Bound, + isRhs: Boolean, + matchType: Boolean = true): Pattern = + Pattern(expr match { + case i: rc.Identifier if freeV.contains(i.name) => + makeExprPatVar(i.name, shiftOfBound(bound), if (isRhs) { Unknown } else { Known }) + case i: rc.Identifier => PatternNode(Var(bound.indexOf(i))) + + // note: we do not match for the type of lambda bodies, as we can always infer it: + // lam(x : xt, e : et) : xt -> et + case rc.Lambda(x, e) => + // right now we assume that all bound variables are uniquely named + if (!isRhs) { + assert(!boundVarToShift.contains(x.name)) + boundVarToShift += x.name -> shiftOfBound(bound) + } + PatternNode(Lambda(makePat(e, bound + x, isRhs, matchType = false))) + case rc.DepLambda(rct.NatKind, x: rct.NatIdentifier, e) => + PatternNode(NatLambda(makePat(e, bound + x, isRhs, matchType = false))) + case rc.DepLambda(rct.DataKind, x: rcdt.DataTypeIdentifier, e) => + PatternNode(DataLambda(makePat(e, bound + x, isRhs, matchType = false))) + case rc.DepLambda(rct.AddressSpaceKind, x: rct.AddressSpaceIdentifier, e) => + PatternNode(AddrLambda(makePat(e, bound + x, isRhs, matchType = false))) + case rc.DepLambda(_, _, _) => ??? + + case rc.App(rc.App(NamedRewriteDSL.Composition(_), f), g) => + PatternNode(Composition( + makePat(f, bound, isRhs, matchType = true), + makePat(g, bound, isRhs, matchType = false))) + + // note: we do not match for the type of applied functions, as we can always infer it: + // app(f : et -> at, e : et) : at + case rc.App(f, e) => + PatternNode(App(makePat(f, bound, isRhs, matchType = false), makePat(e, bound, isRhs))) + case rc.DepApp(rct.NatKind, f, x: rct.Nat) => + PatternNode(NatApp( + makePat(f, bound, isRhs, matchType = false), makeNPat(x, bound, isRhs))) + case rc.DepApp(rct.DataKind, f, x: rct.DataType) => + PatternNode(DataApp( + makePat(f, bound, isRhs, matchType = false), makeDTPat(x, bound, isRhs))) + case rc.DepApp(rct.AddressSpaceKind, f, x: rct.AddressSpace) => + PatternNode(AddrApp( + makePat(f, bound, isRhs, matchType = false), makeAPat(x, bound, isRhs))) + case rc.DepApp(_, _, _) => ??? + + case rc.Literal(rc.semantics.NatData(n)) => + PatternNode(NatLiteral(makeNPat(n, bound, isRhs))) + case rc.Literal(rc.semantics.IndexData(i, n)) => + PatternNode(IndexLiteral(makeNPat(i, bound, isRhs), makeNPat(n, bound, isRhs))) + case rc.Literal(d) => PatternNode(Literal(d)) + // note: we set the primitive type to a place holder here, + // because we do not want type information at the node level + case p: rc.Primitive => PatternNode(Primitive(p.setType(rct.TypePlaceholder))) + + case _ => ??? + }, if (!isRhs && !matchType) TypePatternAny else makeTPat(expr.t, bound, isRhs)) + + def makeNPat(n: rct.Nat, bound: Expr.Bound, isRhs: Boolean): NatPattern = + n match { + case i: rct.NatIdentifier if freeT(rct.NatKind.IDWrapper(i)) => + makeOtherPatVar(i.name, shiftOfBound(bound), + NatPatternVar, if (isRhs) { Unknown } else { Known }) + case i: rct.NatIdentifier => + NatPatternNode(NatVar(bound.indexOf(i))) + case ae.Cst(c) => + NatPatternNode(NatCst(c)) + case ae.Sum(Nil) => NatPatternNode(NatCst(0)) + case ae.Sum(t +: ts) if isRhs => ts.foldRight(makeNPat(t, bound, isRhs)) { case (t, acc) => + NatPatternNode(NatAdd(makeNPat(t, bound, isRhs), acc)) + } + case ae.Prod(Nil) => NatPatternNode(NatCst(1)) + case ae.Prod(t +: ts) if isRhs => ts.foldRight(makeNPat(t, bound, isRhs)) { case (t, acc) => + NatPatternNode(NatMul(makeNPat(t, bound, isRhs), acc)) + } + case ae.Pow(b, e) if isRhs => + NatPatternNode(NatPow(makeNPat(b, bound, isRhs), makeNPat(e, bound, isRhs))) + // do not match over these nat constructs on the left-hand side, + // as structural matching would not be sufficient, + // try to pivot the equality around a fresh pattern variable instead + case ae.Sum(_) | ae.Prod(_) | ae.Pow(_, _) if !isRhs => + val nv = rct.NatIdentifier(s"_nv${natsToPivot.size}") + val shift = shiftOfBound(bound) + val pv = makeOtherPatVar(nv.name, shift, NatPatternVar, Known) + natsToPivot.addOne((n, nv, shift, pv)) + pv + case _ => + throw new Exception(s"did not expect $n") + } + + def makeDTPat(dt: rct.DataType, bound: Expr.Bound, isRhs: Boolean): DataTypePattern = + dt match { + case i: rcdt.DataTypeIdentifier if freeT(IDWrapper(i)) => + makeOtherPatVar(i.name, shiftOfBound(bound), + DataTypePatternVar, if (isRhs) { Unknown } else { Known }) + case i: rcdt.DataTypeIdentifier => + DataTypePatternNode(DataTypeVar(bound.indexOf(i))) + case s: rcdt.ScalarType => + DataTypePatternNode(ScalarType(s)) + case rcdt.NatType => + DataTypePatternNode(NatType) + case rcdt.VectorType(s, et) => + DataTypePatternNode(VectorType(makeNPat(s, bound, isRhs), makeDTPat(et, bound, isRhs))) + case rcdt.IndexType(s) => + DataTypePatternNode(IndexType(makeNPat(s, bound, isRhs))) + case rcdt.PairType(dt1, dt2) => + DataTypePatternNode(PairType(makeDTPat(dt1, bound, isRhs), makeDTPat(dt2, bound, isRhs))) + case rcdt.ArrayType(s, et) => + DataTypePatternNode(ArrayType(makeNPat(s, bound, isRhs), makeDTPat(et, bound, isRhs))) + case _: rcdt.DepArrayType | _: rcdt.DepPairType[_, _] | + _: rcdt.NatToDataApply | _: rcdt.FragmentType | rcdt.ManagedBufferType(_) | rcdt.OpaqueType(_) => + throw new Exception(s"did not expect $dt") + } + + def makeTPat(t: rct.ExprType, bound: Expr.Bound, isRhs: Boolean): TypePattern = + t match { + case dt: rct.DataType => makeDTPat(dt, bound, isRhs) + case rct.FunType(a, b) => + TypePatternNode(FunType(makeTPat(a, bound, isRhs), makeTPat(b, bound, isRhs))) + case rct.DepFunType(rct.NatKind, x: rct.NatIdentifier, t) => + TypePatternNode(NatFunType(makeTPat(t, bound + x, isRhs))) + case rct.DepFunType(rct.DataKind, x: rcdt.DataTypeIdentifier, t) => + TypePatternNode(DataFunType(makeTPat(t, bound + x, isRhs))) + case rct.DepFunType(rct.AddressSpaceKind, x: rct.AddressSpaceIdentifier, t) => + TypePatternNode(AddrFunType(makeTPat(t, bound + x, isRhs))) + case rct.DepFunType(_, _, _) => ??? + case i: rct.TypeIdentifier => + assert(freeT(rct.TypeKind.IDWrapper(i))) + makeOtherPatVar(i.name, shiftOfBound(bound), + TypePatternVar, if (isRhs) { Unknown } else { Known }) + case rct.TypePlaceholder => + throw new Exception(s"did not expect $t, something was not infered") + } + + def makeAPat(a: rct.AddressSpace, bound: Expr.Bound, isRhs: Boolean): AddressPattern = + a match { + case i: rct.AddressSpaceIdentifier if freeT(rct.AddressSpaceKind.IDWrapper(i)) => + makeOtherPatVar(i.name, shiftOfBound(bound), + AddressPatternVar, if (isRhs) { Unknown } else { Known }) + case i: rct.AddressSpaceIdentifier => + AddressPatternNode(AddressVar(bound.indexOf(i))) + case rct.AddressSpace.Global => AddressPatternNode(Global) + case rct.AddressSpace.Local => AddressPatternNode(Local) + case rct.AddressSpace.Private => AddressPatternNode(Private) + case rct.AddressSpace.Constant => AddressPatternNode(Constant) + } + + val lhsPat = makePat(typedLhs, Expr.Bound.empty, isRhs = false) + val rhsPat = makePat(typedRhs, Expr.Bound.empty, isRhs = true) + + def shiftAppliers[S, V](pvm: PatternVarMap[S, V], + mkShift: (S, V) => (S, V) => String => String, + mkShiftCheck: (S, V) => (S, V) => String => String, + ): String => String = { + pvm.foldRight { a: String => a } { case ((name, shiftMap), acc) => + shiftMap.collectFirst { case (s, (v, ShiftCoherent)) => (s, v) } + // if nothing is shift coherent yet, pick any known shift as our reference + .orElse(shiftMap.collectFirst { case (s, (v, Known)) => (s, v) }) match { + case Some(base) => + shiftMap(base._1) = (base._2, ShiftCoherent) + + shiftMap.foldRight(acc) { case ((shift, (pv, status)), acc) => + status match { + // nothing to do + case ShiftCoherent => acc + // check a shifted variable + case Known => + shiftMap(shift) = (pv, ShiftCoherent) + a: String => acc(mkShiftCheck.tupled(base)(shift, pv)(a)) + // construct a shifted variable + case Unknown => + shiftMap(shift) = (pv, ShiftCoherent) + a: String => acc(mkShift.tupled(base)(shift, pv)(a)) + } + } + // nothing is known, but it may become known later (e.g. after nat pivoting) + case None => acc + } + } + } + + def patMkShift(s1: FlatShift, pv1: PatternVar) + (s2: FlatShift, pv2: PatternVar) + (applier: String): String = { + assert(s1 != s2) + val cutoff = s1 + val shift = s2 - s1 + s"""{ shifted("${pv1}", "${pv2}", ${shift}, ${cutoff}, ${applier}) }""" + } + + def patMkShiftCheck(s1: FlatShift, pv1: PatternVar) + (s2: FlatShift, pv2: PatternVar) + (applier: String): String = { + assert(s1 != s2) + val cutoff = s1 + val shift = s2 - s1 + s"""{ shifted_check("${pv1}", "${pv2}", ${shift}, ${cutoff}, ${applier}) }""" } - val rhs = rw.applier match { - case cp: PatternApplier => reggvolve(cp.pattern) - case _ => throw new Exception(s"could not reggvolve applier: ${rw.applier.getClass()}") + + // FIXME: duplicated from type inference's 'pivotSolution' + @scala.annotation.tailrec + def tryPivot(pivot: rct.NatIdentifier, n: rct.Nat, value: rct.Nat): Option[rct.Nat] = { + import arithexpr.arithmetic._ + + n match { + case i: rct.NatIdentifier if i == pivot => Some(value) + case Prod(terms) => + val (p, rest) = terms.partition(t => ArithExpr.contains(t, pivot)) + if (p.size != 1) { + None + } else { + tryPivot(pivot, p.head, rest.foldLeft(value)({ + case (v, r) => v /^ r + })) + } + case Sum(terms) => + val (p, rest) = terms.partition(t => ArithExpr.contains(t, pivot)) + if (p.size != 1) { + None + } else { + tryPivot(pivot, p.head, rest.foldLeft(value)({ + case (v, r) => v - r + })) + } + case Pow(b, Cst(-1)) => tryPivot(pivot, b, Cst(1) /^ value) + case Mod(p, m) if p == pivot => + val k = rct.NatIdentifier(s"_k_${p}_${m}", RangeAdd(0, PosInf, 1)) + Some(k*m + value) + case _ => None + } } - s"""rewrite!("${rw.name}"; "${lhs}" => "${rhs}")""" + + def pivotNatsRec(natsToPivot: Seq[(rct.Nat, rct.NatIdentifier, FlatShift, NatPatternVar)], + couldNotPivot: Seq[(rct.Nat, rct.NatIdentifier, FlatShift, NatPatternVar)]) + (applier: String): String = { + import arithexpr.arithmetic._ + + def pivotSuccess = pivotNatsRec(natsToPivot.tail ++ couldNotPivot, Seq())(applier) + def pivotFailure = pivotNatsRec(natsToPivot.tail, couldNotPivot :+ natsToPivot.head)(applier) + + natsToPivot.headOption match { + case Some((n, nv, shift, pv)) => + def fromNamed(n: rct.Nat): NatPattern = { + n match { + case i: rct.NatIdentifier => + makeOtherPatVar(i.name, shift, NatPatternVar, Unknown) + case PosInf => NatPatternNode(NatPosInf) + case NegInf => NatPatternNode(NatNegInf) + case Cst(c) => NatPatternNode(NatCst(c)) + case Sum(Nil) => NatPatternNode(NatCst(0)) + case Sum(t +: ts) => ts.foldRight(fromNamed(t)) { case (t, acc) => + NatPatternNode(NatAdd(fromNamed(t), acc)) + } + case Prod(Nil) => NatPatternNode(NatCst(1)) + case Prod(t +: ts) => ts.foldRight(fromNamed(t)) { case (t, acc) => + NatPatternNode(NatMul(fromNamed(t), acc)) + } + case Pow(b, e) => + NatPatternNode(NatPow(fromNamed(b), fromNamed(e))) + case Mod(a, b) => + NatPatternNode(NatMod(fromNamed(a), fromNamed(b))) + case IntDiv(a, b) => + NatPatternNode(NatIntDiv(fromNamed(a), fromNamed(b))) + case _ => throw new Exception(s"no support for $n") + } + } + + val natsToFindOut = HashMap[rct.NatIdentifier, Integer]().withDefault(_ => 0) + ArithExpr.visit(n, { + case ni: rct.NatIdentifier => + val isKnown = patVars.get(ni.name) + .exists(shiftMap => shiftMap.exists { case (s, (pv, status)) => status != Unknown }) + if (!isKnown) { + natsToFindOut(ni) += 1 + } + case _ => + }) + natsToFindOut.size match { + case 0 => // check nv = n + val valuePat = fromNamed(n) + val updateShifts = shiftAppliers(patVars, patMkShift, patMkShiftCheck) + val vp = reggvolve(valuePat) + updateShifts(s"""{ compute_nat_check("${pv}", "${vp}", ${pivotSuccess}) }""") + // ComputeNatCheckApplier(pv, valuePat, pivotSuccess)) + case 1 => + val (potentialPivot, uses) = natsToFindOut.head + if (uses == 1) { + tryPivot(potentialPivot, n, nv) match { + case Some(value) => + val valuePat = fromNamed(value) + val updateShifts = shiftAppliers(patVars, patMkShift, patMkShiftCheck) + val pivotPat = makeOtherPatVar(potentialPivot.name, shift, + NatPatternVar, Known) + val vp = reggvolve(valuePat) + val a = shiftAppliers(patVars, patMkShift, patMkShiftCheck)(pivotSuccess) + updateShifts(s"""{ compute_nat("${pivotPat}", "${vp}", ${a}) }""") + // ComputeNatApplier(pivotPat, valuePat, a) + case None => pivotFailure + } + } else { + pivotFailure + } + case _ => pivotFailure + } + case None => + if (couldNotPivot.nonEmpty) { + throw new Exception(s"could not pivot nats: $couldNotPivot") + } else { + applier + } + } + } + + val searcher: String = s""""${reggvolve(lhsPat)}"""" + val param = parameters.foldRight((a: String) => a) { case (c, acc) => + c match { + case NotFreeIn(notFree, in) => + val nfShift = boundVarToShift.getOrElse(notFree, 0) + // all left-hand-side uses of `in` may contain `notFree` + assert(patVars(in).forall { + case (shift, (_, status)) => + shift >= nfShift || status != Known + }) + // pick one of these uses + val (iS, iPV) = patVars(in).collectFirst { + case (s, (pv, Known)) => (s, pv) + }.get + val nfIndex = iS - nfShift // >= 0 because iS >= nfShift + (a: String) => s"""{ not_free_in("${iPV}", ${nfIndex}, ${a}) }""" + // NotFreeInApplier(iPV, nfIndex, acc(a)) + case VectorizeScalarFun(f, n, fV) => + val (enPV, nST) = patVars(n)(0) + val nPV = enPV match { + case PatternVar(index) => NatPatternVar(index) + } + assert(nST == Known) + val (fPV, fST) = patVars(f)(0) + assert(fST == Known) + val fVPV = makePatVar(fV, 0, patVars, PatternVar, Known) + (a: String) => s"""{ vectorize_scalar_fun("${fPV}", ${nPV}, "${fVPV}", ${a}) }""" + // VectorizeScalarFunExtractApplier(fPV, nPV, fVPV, acc(a)) + } + } + val shiftPV = shiftAppliers(patVars, patMkShift, patMkShiftCheck) + val pivotNats = pivotNatsRec(natsToPivot.toSeq, Seq()) _ + val rhsPatApplier = s""""${reggvolve(rhsPat)}"""" + val applier = param(shiftPV(pivotNats(rhsPatApplier))) + + def allIsShiftCoherent[S, V](pvm: PatternVarMap[S, V]): Boolean = + pvm.forall { case (_, shiftMap) => + shiftMap.forall { case (_, (_, status)) => status == ShiftCoherent }} + assert(allIsShiftCoherent(patVars)) + + s"""rewrite!("${name}"; ${searcher} => ${applier})""" } + // DEPRECATED: + // def reggvolve(searcher: Searcher): String = + // def reggvolve(applier: Applier): String = + def reggvolve(pat: Pattern): String = reggvolve(pat, (0, 0, 0, 0, 0)) + def reggvolve(pat: NatPattern): String = + reggvolve(pat, (0, 0, 0, 0, 0)) + def reggvolve(pat: Pattern, s: Shift): String = { val e = pat.p match { case PatternVar(index) => s"?e${index}" diff --git a/src/main/scala/rise/eqsat/Rewrite.scala b/src/main/scala/rise/eqsat/Rewrite.scala index 05e7b391f..b658e09f5 100644 --- a/src/main/scala/rise/eqsat/Rewrite.scala +++ b/src/main/scala/rise/eqsat/Rewrite.scala @@ -133,6 +133,15 @@ abstract class ConditionalApplier(condPatternVars: Set[Any], } } +case class NotFreeInApplier( + iPV: PatternVar, nfIndex: Int, applier: Applier +) extends ConditionalApplier(Set(iPV), (Set(FreeAnalysis), Set()), applier) { + override def cond(egraph: EGraph, eclass: EClassId, shc: Substs)(subst: shc.Subst): Boolean = { + val freeOf = egraph.getAnalysis(FreeAnalysis) + !freeOf(shc.get(iPV, subst)).free.contains(nfIndex) + } +} + /** An [[Applier]] that shifts the DeBruijn indices of a variable */ case class ShiftedApplier(v: PatternVar, newV: PatternVar, shift: Expr.Shift, cutoff: Expr.Shift, @@ -155,10 +164,11 @@ case class ShiftedApplier(v: PatternVar, newV: PatternVar, /** An [[Applier]] that shifts the DeBruijn indices of a variable. * @note It works by extracting an expression from the [[EGraph]] in order to shift it. */ -case class ShiftedExtractApplier(v: PatternVar, newV: PatternVar, - shift: Expr.Shift, cutoff: Expr.Shift, - applier: Applier) - extends Applier { +case class ShiftedExtractApplier( + v: PatternVar, newV: PatternVar, + shift: Expr.Shift, cutoff: Expr.Shift, + applier: Applier +) extends Applier { override def patternVars(): Set[Any] = applier.patternVars() - newV + v @@ -180,26 +190,26 @@ case class ShiftedExtractApplier(v: PatternVar, newV: PatternVar, /** An [[Applier]] that checks whether a shifted variable is equal to another * @note It works by extracting an expression from the [[EGraph]] in order to shift it. */ -object ShiftedCheckApplier { - def apply(v: PatternVar, v2: PatternVar, - shift: Expr.Shift, cutoff: Expr.Shift, - applier: Applier): Applier = - new ConditionalApplier(Set(v, v2), (Set(SmallestSizeAnalysis), Set()), applier) { - override def cond(egraph: EGraph, id: EClassId, substs: Substs)(subst: substs.Subst): Boolean = { - val smallestOf = egraph.getAnalysis(SmallestSizeAnalysis) - val extract = smallestOf(substs.get(v, subst))._1 - val shifted = extract.shifted(egraph, shift, cutoff) - val expected = smallestOf(substs.get(v2, subst))._1 - shifted == expected - } - } +case class ShiftedCheckApplier( + v: PatternVar, v2: PatternVar, + shift: Expr.Shift, cutoff: Expr.Shift, + applier: Applier +) extends ConditionalApplier(Set(v, v2), (Set(SmallestSizeAnalysis), Set()), applier) { + override def cond(egraph: EGraph, id: EClassId, substs: Substs)(subst: substs.Subst): Boolean = { + val smallestOf = egraph.getAnalysis(SmallestSizeAnalysis) + val extract = smallestOf(substs.get(v, subst))._1 + val shifted = extract.shifted(egraph, shift, cutoff) + val expected = smallestOf(substs.get(v2, subst))._1 + shifted == expected + } } /** An [[Applier]] that shifts the DeBruijn indices of a nat variable */ -case class ShiftedNatApplier(v: NatPatternVar, newV: NatPatternVar, - shift: Nat.Shift, cutoff: Nat.Shift, - applier: Applier) - extends Applier { +case class ShiftedNatApplier( + v: NatPatternVar, newV: NatPatternVar, + shift: Nat.Shift, cutoff: Nat.Shift, + applier: Applier +) extends Applier { override def patternVars(): Set[Any] = applier.patternVars() - newV + v @@ -218,25 +228,25 @@ case class ShiftedNatApplier(v: NatPatternVar, newV: NatPatternVar, } /** An [[Applier]] that checks whether a shifted nat variable is equal to another */ -object ShiftedNatCheckApplier { - def apply(v: NatPatternVar, v2: NatPatternVar, - shift: Nat.Shift, cutoff: Nat.Shift, - applier: Applier): Applier = - new ConditionalApplier(Set(v, v2), (Set(), Set()), applier) { - override def cond(egraph: EGraph, id: EClassId, substs: Substs)(subst: substs.Subst): Boolean = { - val nat = substs.get(v, subst) - val shifted = NodeSubs.Nat.shifted(egraph, nat, shift, cutoff) - val expected = substs.get(v2, subst) - shifted == expected - } - } +case class ShiftedNatCheckApplier( + v: NatPatternVar, v2: NatPatternVar, + shift: Nat.Shift, cutoff: Nat.Shift, + applier: Applier +) extends ConditionalApplier(Set(v, v2), (Set(), Set()), applier) { + override def cond(egraph: EGraph, id: EClassId, substs: Substs)(subst: substs.Subst): Boolean = { + val nat = substs.get(v, subst) + val shifted = NodeSubs.Nat.shifted(egraph, nat, shift, cutoff) + val expected = substs.get(v2, subst) + shifted == expected + } } /** An [[Applier]] that shifts the DeBruijn indices of a data type variable */ -case class ShiftedDataTypeApplier(v: DataTypePatternVar, newV: DataTypePatternVar, - shift: Type.Shift, cutoff: Type.Shift, - applier: Applier) - extends Applier { +case class ShiftedDataTypeApplier( + v: DataTypePatternVar, newV: DataTypePatternVar, + shift: Type.Shift, cutoff: Type.Shift, + applier: Applier +) extends Applier { override def patternVars(): Set[Any] = applier.patternVars() - newV + v @@ -255,25 +265,25 @@ case class ShiftedDataTypeApplier(v: DataTypePatternVar, newV: DataTypePatternVa } /** An [[Applier]] that checks whether a shifted nat variable is equal to another */ -object ShiftedDataTypeCheckApplier { - def apply(v: DataTypePatternVar, v2: DataTypePatternVar, - shift: Type.Shift, cutoff: Type.Shift, - applier: Applier): Applier = - new ConditionalApplier(Set(v, v2), (Set(), Set()), applier) { - override def cond(egraph: EGraph, id: EClassId, substs: Substs)(subst: substs.Subst): Boolean = { - val dt = substs.get(v, subst) - val shifted = NodeSubs.DataType.shifted(egraph, dt, shift, cutoff) - val expected = substs.get(v2, subst) - shifted == expected - } - } +case class ShiftedDataTypeCheckApplier( + v: DataTypePatternVar, v2: DataTypePatternVar, + shift: Type.Shift, cutoff: Type.Shift, + applier: Applier +) extends ConditionalApplier(Set(v, v2), (Set(), Set()), applier) { + override def cond(egraph: EGraph, id: EClassId, substs: Substs)(subst: substs.Subst): Boolean = { + val dt = substs.get(v, subst) + val shifted = NodeSubs.DataType.shifted(egraph, dt, shift, cutoff) + val expected = substs.get(v2, subst) + shifted == expected + } } /** An [[Applier]] that shifts the DeBruijn indices of a type variable */ -case class ShiftedTypeApplier(v: TypePatternVar, newV: TypePatternVar, - shift: Type.Shift, cutoff: Type.Shift, - applier: Applier) - extends Applier { +case class ShiftedTypeApplier( + v: TypePatternVar, newV: TypePatternVar, + shift: Type.Shift, cutoff: Type.Shift, + applier: Applier +) extends Applier { override def patternVars(): Set[Any] = applier.patternVars() - newV + v @@ -292,18 +302,17 @@ case class ShiftedTypeApplier(v: TypePatternVar, newV: TypePatternVar, } /** An [[Applier]] that checks whether a shifted nat variable is equal to another */ -object ShiftedTypeCheckApplier { - def apply(v: TypePatternVar, v2: TypePatternVar, - shift: Type.Shift, cutoff: Type.Shift, - applier: Applier): Applier = - new ConditionalApplier(Set(v, v2), (Set(), Set()), applier) { - override def cond(egraph: EGraph, id: EClassId, substs: Substs)(subst: substs.Subst): Boolean = { - val t = substs.get(v, subst) - val shifted = NodeSubs.Type.shifted(egraph, t, shift, cutoff) - val expected = substs.get(v2, subst) - shifted == expected - } - } +case class ShiftedTypeCheckApplier( + v: TypePatternVar, v2: TypePatternVar, + shift: Type.Shift, cutoff: Type.Shift, + applier: Applier +) extends ConditionalApplier(Set(v, v2), (Set(), Set()), applier) { + override def cond(egraph: EGraph, id: EClassId, substs: Substs)(subst: substs.Subst): Boolean = { + val t = substs.get(v, subst) + val shifted = NodeSubs.Type.shifted(egraph, t, shift, cutoff) + val expected = substs.get(v2, subst) + shifted == expected + } } /** An [[Applier]] that performs beta-reduction. @@ -389,21 +398,21 @@ case class BetaNatExtractApplier(body: PatternVar, subs: NatPatternVar) } /** An [[Applier]] that checks whether a nat variable is equal to a nat pattern */ -object ComputeNatCheckApplier { - def apply(v: NatPatternVar, expected: NatPattern, - applier: Applier): Applier = - new ConditionalApplier(expected.patternVars() + v, (Set(), Set()), applier) { - override def cond(egraph: EGraph, id: EClassId, substs: Substs)(subst: substs.Subst): Boolean = { - // TODO: can we be more efficient here? - ComputeNat.toNamed(egraph, v, substs)(subst) == ComputeNat.toNamed(egraph, expected, substs)(subst) - } - } +case class ComputeNatCheckApplier( + v: NatPatternVar, expected: NatPattern, + applier: Applier +) extends ConditionalApplier(expected.patternVars() + v, (Set(), Set()), applier) { + override def cond(egraph: EGraph, id: EClassId, substs: Substs)(subst: substs.Subst): Boolean = { + // TODO: can we be more efficient here? + ComputeNat.toNamed(egraph, v, substs)(subst) == ComputeNat.toNamed(egraph, expected, substs)(subst) + } } /** An [[Applier]] that computes a nat variable according to a nat pattern */ -case class ComputeNatApplier(v: NatPatternVar, value: NatPattern, - applier: Applier) extends Applier { - +case class ComputeNatApplier( + v: NatPatternVar, value: NatPattern, + applier: Applier +) extends Applier { override def patternVars(): Set[Any] = applier.patternVars() - v ++ value.patternVars() @@ -502,9 +511,10 @@ private object ComputeNat { /** An [[Applier]] that vectorizes a scalar function. * @note It works by extracting an expression from the [[EGraph]] in order to vectorize it. */ -case class VectorizeScalarFunExtractApplier(f: PatternVar, n: NatPatternVar, fV: PatternVar, - applier: Applier) - extends Applier { +case class VectorizeScalarFunExtractApplier( + f: PatternVar, n: NatPatternVar, fV: PatternVar, + applier: Applier +) extends Applier { override def patternVars(): Set[Any] = applier.patternVars() - fV override def requiredAnalyses(): (Set[Analysis], Set[TypeAnalysis]) = From bea11a5b53dc2a1d4e1fb7f981c4053a9fec60f2 Mon Sep 17 00:00:00 2001 From: Thomas Koehler Date: Tue, 30 Sep 2025 17:44:23 +0200 Subject: [PATCH 6/7] factorize and fix code --- src/main/scala/rise/eqsat/NamedRewrite.scala | 416 +++++++++++-------- src/main/scala/rise/eqsat/Reggvolution.scala | 368 ++-------------- 2 files changed, 288 insertions(+), 496 deletions(-) diff --git a/src/main/scala/rise/eqsat/NamedRewrite.scala b/src/main/scala/rise/eqsat/NamedRewrite.scala index 27245f918..b7e60020d 100644 --- a/src/main/scala/rise/eqsat/NamedRewrite.scala +++ b/src/main/scala/rise/eqsat/NamedRewrite.scala @@ -72,6 +72,8 @@ object NamedRewrite { // from var name to var index and a status depending on local index shift type PatternVarMap[S, V] = HashMap[String, HashMap[S, (V, PatVarStatus)]] + // take a global named pattern variable and make or retrieve an index + // pattern variable that is local to the surrounding DeBruijn index shift context. def makePatVar[S, V]( name: String, shift: S, @@ -94,33 +96,38 @@ object NamedRewrite { pv } - def init(name: String, - rule: (NamedRewriteDSL.Pattern, NamedRewriteDSL.Pattern), - parameters: Seq[NamedRewrite.Parameter] = Seq(), - ): Rewrite = { + // take a named pattern (lhs or rhs of rule) and transform it into an + // index-based pattern, updating the pattern variable maps on the way. + // also updates which bounds variables need a certain shift to be available, + // and which nats to pivot to avoid matching over certain nat constructs. + def makePat[S, NS, TS, AS]( + expr: rc.Expr, + bound: Expr.Bound, + isRhs: Boolean, + freeV: Map[String, rct.ExprType], + freeT: Set[rct.Kind.Identifier], + shiftOfBound: Expr.Bound => S, + natShiftOfBound: Expr.Bound => NS, + typeShiftOfBound: Expr.Bound => TS, + addrShiftOfBound: Expr.Bound => AS, + patVars: PatternVarMap[S, PatternVar], + natPatVars: PatternVarMap[NS, NatPatternVar], + dataTypePatVars: PatternVarMap[TS, DataTypePatternVar], + typePatVars: PatternVarMap[TS, TypePatternVar], + addrPatVars: PatternVarMap[AS, AddressPatternVar], + natsToPivot: Vec[(rct.Nat, rct.NatIdentifier, NS, NatPatternVar)], + boundVarToShift: HashMap[String, S], + // matchType: Boolean = true + ): Pattern = { import arithexpr.{arithmetic => ae} - val (typedLhs, freeV, freeT, typedRhs) = typeRule(rule, parameters) - - val patVars: PatternVarMap[Expr.Shift, PatternVar] = HashMap() - val natPatVars: PatternVarMap[Nat.Shift, NatPatternVar] = HashMap() - val dataTypePatVars: PatternVarMap[Type.Shift, DataTypePatternVar] = HashMap() - val typePatVars: PatternVarMap[Type.Shift, TypePatternVar] = HashMap() - val addrPatVars: PatternVarMap[Address.Shift, AddressPatternVar] = HashMap() - - // nats which we need to pivot to avoid matching over certain nat constructs - val natsToPivot = Vec[(rct.Nat, rct.NatIdentifier, Nat.Shift, NatPatternVar)]() - - val boundVarToShift = HashMap[String, Expr.Shift]() - def makePat(expr: rc.Expr, bound: Expr.Bound, isRhs: Boolean, matchType: Boolean = true): Pattern = Pattern(expr match { case i: rc.Identifier if freeV.contains(i.name) => - makePatVar(i.name, - (bound.expr.size, bound.nat.size, bound.data.size, bound.addr.size, bound.n2n.size), + makePatVar(i.name, shiftOfBound(bound), patVars, PatternVar, if (isRhs) { Unknown } else { Known }) case i: rc.Identifier => PatternNode(Var(bound.indexOf(i))) @@ -128,12 +135,12 @@ object NamedRewrite { // lam(x : xt, e : et) : xt -> et case rc.Lambda(x, e) => // right now we assume that all bound variables are uniquely named + val newBound = bound + x if (!isRhs) { assert(!boundVarToShift.contains(x.name)) - boundVarToShift += x.name -> - (bound.expr.size + 1, bound.nat.size, bound.data.size, bound.addr.size, bound.n2n.size) + boundVarToShift += x.name -> shiftOfBound(newBound) } - PatternNode(Lambda(makePat(e, bound + x, isRhs, matchType = false))) + PatternNode(Lambda(makePat(e, newBound, isRhs, matchType = false))) case rc.DepLambda(rct.NatKind, x: rct.NatIdentifier, e) => PatternNode(NatLambda(makePat(e, bound + x, isRhs, matchType = false))) case rc.DepLambda(rct.DataKind, x: rcdt.DataTypeIdentifier, e) => @@ -177,7 +184,7 @@ object NamedRewrite { def makeNPat(n: rct.Nat, bound: Expr.Bound, isRhs: Boolean): NatPattern = n match { case i: rct.NatIdentifier if freeT(rct.NatKind.IDWrapper(i)) => - makePatVar(i.name, (bound.nat.size, bound.n2n.size), natPatVars, + makePatVar(i.name, natShiftOfBound(bound), natPatVars, NatPatternVar, if (isRhs) { Unknown } else { Known }) case i: rct.NatIdentifier => NatPatternNode(NatVar(bound.indexOf(i))) @@ -198,8 +205,9 @@ object NamedRewrite { // try to pivot the equality around a fresh pattern variable instead case ae.Sum(_) | ae.Prod(_) | ae.Pow(_, _) if !isRhs => val nv = rct.NatIdentifier(s"_nv${natsToPivot.size}") - val pv = makePatVar(nv.name, (bound.nat.size, bound.n2n.size), natPatVars, NatPatternVar, Known) - natsToPivot.addOne((n, nv, (bound.nat.size, bound.n2n.size), pv)) + val shift = natShiftOfBound(bound) + val pv = makePatVar(nv.name, shift, natPatVars, NatPatternVar, Known) + natsToPivot.addOne((n, nv, shift, pv)) pv case _ => throw new Exception(s"did not expect $n") @@ -208,7 +216,7 @@ object NamedRewrite { def makeDTPat(dt: rct.DataType, bound: Expr.Bound, isRhs: Boolean): DataTypePattern = dt match { case i: rcdt.DataTypeIdentifier if freeT(IDWrapper(i)) => - makePatVar(i.name, (bound.nat.size, bound.data.size, bound.n2n.size), + makePatVar(i.name, typeShiftOfBound(bound), dataTypePatVars, DataTypePatternVar, if (isRhs) { Unknown } else { Known }) case i: rcdt.DataTypeIdentifier => DataTypePatternNode(DataTypeVar(bound.indexOf(i))) @@ -225,7 +233,7 @@ object NamedRewrite { case rcdt.ArrayType(s, et) => DataTypePatternNode(ArrayType(makeNPat(s, bound, isRhs), makeDTPat(et, bound, isRhs))) case _: rcdt.DepArrayType | _: rcdt.DepPairType[_, _] | - _: rcdt.NatToDataApply | _: rcdt.FragmentType | rcdt.ManagedBufferType(_) | rcdt.OpaqueType(_) => + _: rcdt.NatToDataApply | _: rcdt.FragmentType | rcdt.ManagedBufferType(_) | rcdt.OpaqueType(_) => throw new Exception(s"did not expect $dt") } @@ -243,8 +251,8 @@ object NamedRewrite { case rct.DepFunType(_, _, _) => ??? case i: rct.TypeIdentifier => assert(freeT(rct.TypeKind.IDWrapper(i))) - makePatVar(i.name, (bound.nat.size, bound.data.size, bound.n2n.size), - typePatVars, TypePatternVar, if (isRhs) { Unknown } else { Known }) + makePatVar(i.name, typeShiftOfBound(bound), typePatVars, + TypePatternVar, if (isRhs) { Unknown } else { Known }) case rct.TypePlaceholder => throw new Exception(s"did not expect $t, something was not infered") } @@ -252,7 +260,7 @@ object NamedRewrite { def makeAPat(a: rct.AddressSpace, bound: Expr.Bound, isRhs: Boolean): AddressPattern = a match { case i: rct.AddressSpaceIdentifier if freeT(rct.AddressSpaceKind.IDWrapper(i)) => - makePatVar(i.name, bound.addr.size, addrPatVars, + makePatVar(i.name, addrShiftOfBound(bound), addrPatVars, AddressPatternVar, if (isRhs) { Unknown } else { Known }) case i: rct.AddressSpaceIdentifier => AddressPatternNode(AddressVar(bound.indexOf(i))) @@ -262,40 +270,216 @@ object NamedRewrite { case rct.AddressSpace.Constant => AddressPatternNode(Constant) } - val lhsPat = makePat(typedLhs, Expr.Bound.empty, isRhs = false) - val rhsPat = makePat(typedRhs, Expr.Bound.empty, isRhs = true) - - def shiftAppliers[S, V](pvm: PatternVarMap[S, V], - mkShift: (S, V) => (S, V) => Applier => Applier, - mkShiftCheck: (S, V) => (S, V) => Applier => Applier, - ): Applier => Applier = { - pvm.foldRight { a: Applier => a } { case ((name, shiftMap), acc) => - shiftMap.collectFirst { case (s, (v, ShiftCoherent)) => (s, v) } - // if nothing is shift coherent yet, pick any known shift as our reference - .orElse(shiftMap.collectFirst { case (s, (v, Known)) => (s, v) }) match { - case Some(base) => - shiftMap(base._1) = (base._2, ShiftCoherent) - - shiftMap.foldRight(acc) { case ((shift, (pv, status)), acc) => - status match { - // nothing to do - case ShiftCoherent => acc - // check a shifted variable - case Known => - shiftMap(shift) = (pv, ShiftCoherent) - a: Applier => acc(mkShiftCheck.tupled(base)(shift, pv)(a)) - // construct a shifted variable - case Unknown => - shiftMap(shift) = (pv, ShiftCoherent) - a: Applier => acc(mkShift.tupled(base)(shift, pv)(a)) - } + makePat(expr, bound, isRhs) + } + + // take a pattern variable map, and apply shifts as necessary to either + // check equivalence to another known shift, or construct an unknown shift. + // this is necessary because the same named variable corresponds to multiple + // index-based variables used in different shift contexts. + def shiftAppliers[S, V, A]( + pvm: PatternVarMap[S, V], + mkShift: (S, V) => (S, V) => A => A, + mkShiftCheck: (S, V) => (S, V) => A => A, + ): A => A = { + pvm.foldRight { a: A => a } { case ((name, shiftMap), acc) => + shiftMap.collectFirst { case (s, (v, ShiftCoherent)) => (s, v) } + // if nothing is shift coherent yet, pick any known shift as our reference + .orElse(shiftMap.collectFirst { case (s, (v, Known)) => (s, v) }) match { + case Some(base) => + shiftMap(base._1) = (base._2, ShiftCoherent) + + shiftMap.foldRight(acc) { case ((shift, (pv, status)), acc) => + status match { + // nothing to do + case ShiftCoherent => acc + // check a shifted variable + case Known => + shiftMap(shift) = (pv, ShiftCoherent) + a: A => acc(mkShiftCheck.tupled(base)(shift, pv)(a)) + // construct a shifted variable + case Unknown => + shiftMap(shift) = (pv, ShiftCoherent) + a: A => acc(mkShift.tupled(base)(shift, pv)(a)) } - // nothing is known, but it may become known later (e.g. after nat pivoting) - case None => acc + } + // nothing is known, but it may become known later (e.g. after nat pivoting) + case None => acc + } + } + } + + // FIXME: duplicated from type inference's 'pivotSolution' + @scala.annotation.tailrec + def tryPivot( + pivot: rct.NatIdentifier, + n: rct.Nat, + value: rct.Nat, + ): Option[rct.Nat] = { + import arithexpr.arithmetic._ + + n match { + case i: rct.NatIdentifier if i == pivot => Some(value) + case Prod(terms) => + val (p, rest) = terms.partition(t => ArithExpr.contains(t, pivot)) + if (p.size != 1) { + None + } else { + tryPivot(pivot, p.head, rest.foldLeft(value)({ + case (v, r) => v /^ r + })) + } + case Sum(terms) => + val (p, rest) = terms.partition(t => ArithExpr.contains(t, pivot)) + if (p.size != 1) { + None + } else { + tryPivot(pivot, p.head, rest.foldLeft(value)({ + case (v, r) => v - r + })) } + case Pow(b, Cst(-1)) => tryPivot(pivot, b, Cst(1) /^ value) + case Mod(p, m) if p == pivot => + val k = rct.NatIdentifier(s"_k_${p}_${m}", RangeAdd(0, PosInf, 1)) + Some(k*m + value) + case _ => None + } + } + + // given nats to pivot in order to avoid matching over certain nat constructs, + // attempts to pivot them, failing otherwise. + def pivotNats[NS, A]( + natsToPivot: Seq[(rct.Nat, rct.NatIdentifier, NS, NatPatternVar)], + natPatVars: PatternVarMap[NS, NatPatternVar], + natPatMkShift: (NS, NatPatternVar) => (NS, NatPatternVar) => (A) => A, + natPatMkShiftCheck: (NS, NatPatternVar) => (NS, NatPatternVar) => (A) => A, + mkComputeNatCheck: (NatPatternVar, NatPattern, A) => A, + mkComputeNat: (NatPatternVar, NatPattern, A) => A, + ): A => A = { + def rec( + natsToPivot: Seq[(rct.Nat, rct.NatIdentifier, NS, NatPatternVar)], + couldNotPivot: Seq[(rct.Nat, rct.NatIdentifier, NS, NatPatternVar)]) + (applier: A + ): A = { + import arithexpr.arithmetic._ + + def pivotSuccess = rec(natsToPivot.tail ++ couldNotPivot, Seq())(applier) + def pivotFailure = rec(natsToPivot.tail, couldNotPivot :+ natsToPivot.head)(applier) + + natsToPivot.headOption match { + case Some((n, nv, shift, pv)) => + def fromNamed(n: rct.Nat): NatPattern = { + n match { + case i: rct.NatIdentifier => + makePatVar(i.name, shift, natPatVars, NatPatternVar, Unknown) + case PosInf => NatPatternNode(NatPosInf) + case NegInf => NatPatternNode(NatNegInf) + case Cst(c) => NatPatternNode(NatCst(c)) + case Sum(Nil) => NatPatternNode(NatCst(0)) + case Sum(t +: ts) => ts.foldRight(fromNamed(t)) { case (t, acc) => + NatPatternNode(NatAdd(fromNamed(t), acc)) + } + case Prod(Nil) => NatPatternNode(NatCst(1)) + case Prod(t +: ts) => ts.foldRight(fromNamed(t)) { case (t, acc) => + NatPatternNode(NatMul(fromNamed(t), acc)) + } + case Pow(b, e) => + NatPatternNode(NatPow(fromNamed(b), fromNamed(e))) + case Mod(a, b) => + NatPatternNode(NatMod(fromNamed(a), fromNamed(b))) + case IntDiv(a, b) => + NatPatternNode(NatIntDiv(fromNamed(a), fromNamed(b))) + case _ => throw new Exception(s"no support for $n") + } + } + + val natsToFindOut = HashMap[rct.NatIdentifier, Integer]().withDefault(_ => 0) + ArithExpr.visit(n, { + case ni: rct.NatIdentifier => + val isKnown = natPatVars.get(ni.name) + .exists(shiftMap => shiftMap.exists { case (s, (pv, status)) => status != Unknown }) + if (!isKnown) { + natsToFindOut(ni) += 1 + } + case _ => + }) + natsToFindOut.size match { + case 0 => // check nv = n + val valuePat = fromNamed(n) + val updateShifts = shiftAppliers(natPatVars, natPatMkShift, natPatMkShiftCheck) + updateShifts(mkComputeNatCheck(pv, valuePat, pivotSuccess)) + case 1 => + val (potentialPivot, uses) = natsToFindOut.head + if (uses == 1) { + tryPivot(potentialPivot, n, nv) match { + case Some(value) => + val valuePat = fromNamed(value) + val updateShifts = shiftAppliers(natPatVars, natPatMkShift, natPatMkShiftCheck) + val pivotPat = makePatVar(potentialPivot.name, shift, + natPatVars, NatPatternVar, Known) + updateShifts(mkComputeNat(pivotPat, valuePat, + shiftAppliers(natPatVars, natPatMkShift, natPatMkShiftCheck)(pivotSuccess))) + case None => pivotFailure + } + } else { + pivotFailure + } + case _ => pivotFailure + } + case None => + if (couldNotPivot.nonEmpty) { + throw new Exception(s"could not pivot nats: $couldNotPivot") + } else { + applier + } } } + rec(natsToPivot, Seq()) + } + + def init(name: String, + rule: (NamedRewriteDSL.Pattern, NamedRewriteDSL.Pattern), + parameters: Seq[NamedRewrite.Parameter] = Seq(), + ): Rewrite = { + import arithexpr.{arithmetic => ae} + + val (typedLhs, freeV, freeT, typedRhs) = typeRule(rule, parameters) + + val patVars: PatternVarMap[Expr.Shift, PatternVar] = HashMap() + val natPatVars: PatternVarMap[Nat.Shift, NatPatternVar] = HashMap() + val dataTypePatVars: PatternVarMap[Type.Shift, DataTypePatternVar] = HashMap() + val typePatVars: PatternVarMap[Type.Shift, TypePatternVar] = HashMap() + val addrPatVars: PatternVarMap[Address.Shift, AddressPatternVar] = HashMap() + + // nats which we need to pivot to avoid matching over certain nat constructs + val natsToPivot = Vec[(rct.Nat, rct.NatIdentifier, Nat.Shift, NatPatternVar)]() + + val boundVarToShift = HashMap[String, Expr.Shift]() + + def shiftOfBound(bound: Expr.Bound): Expr.Shift = + (bound.expr.size, bound.nat.size, bound.data.size, bound.addr.size, bound.n2n.size) + + def natShiftOfBound(bound: Expr.Bound): Nat.Shift = + (bound.nat.size, bound.n2n.size) + + def typeShiftOfBound(bound: Expr.Bound): Type.Shift = + (bound.nat.size, bound.data.size, bound.n2n.size) + + def addrShiftOfBound(bound: Expr.Bound): Address.Shift = + bound.addr.size + + val lhsPat = makePat(typedLhs, Expr.Bound.empty, isRhs = false, + freeV, freeT, + shiftOfBound, natShiftOfBound, typeShiftOfBound, addrShiftOfBound, + patVars, natPatVars, dataTypePatVars, typePatVars, addrPatVars, + natsToPivot, boundVarToShift) + val rhsPat = makePat(typedRhs, Expr.Bound.empty, isRhs = true, + freeV, freeT, + shiftOfBound, natShiftOfBound, typeShiftOfBound, addrShiftOfBound, + patVars, natPatVars, dataTypePatVars, typePatVars, addrPatVars, + natsToPivot, boundVarToShift) + def patMkShift(s1: Expr.Shift, pv1: PatternVar) (s2: Expr.Shift, pv2: PatternVar) (applier: Applier): Applier = { @@ -389,116 +573,6 @@ object NamedRewrite { ??? } - // FIXME: duplicated from type inference's 'pivotSolution' - @scala.annotation.tailrec - def tryPivot(pivot: rct.NatIdentifier, n: rct.Nat, value: rct.Nat): Option[rct.Nat] = { - import arithexpr.arithmetic._ - - n match { - case i: rct.NatIdentifier if i == pivot => Some(value) - case Prod(terms) => - val (p, rest) = terms.partition(t => ArithExpr.contains(t, pivot)) - if (p.size != 1) { - None - } else { - tryPivot(pivot, p.head, rest.foldLeft(value)({ - case (v, r) => v /^ r - })) - } - case Sum(terms) => - val (p, rest) = terms.partition(t => ArithExpr.contains(t, pivot)) - if (p.size != 1) { - None - } else { - tryPivot(pivot, p.head, rest.foldLeft(value)({ - case (v, r) => v - r - })) - } - case Pow(b, Cst(-1)) => tryPivot(pivot, b, Cst(1) /^ value) - case Mod(p, m) if p == pivot => - val k = rct.NatIdentifier(s"_k_${p}_${m}", RangeAdd(0, PosInf, 1)) - Some(k*m + value) - case _ => None - } - } - - def pivotNatsRec(natsToPivot: Seq[(rct.Nat, rct.NatIdentifier, Nat.Shift, NatPatternVar)], - couldNotPivot: Seq[(rct.Nat, rct.NatIdentifier, Nat.Shift, NatPatternVar)]) - (applier: Applier): Applier = { - import arithexpr.arithmetic._ - - def pivotSuccess = pivotNatsRec(natsToPivot.tail ++ couldNotPivot, Seq())(applier) - def pivotFailure = pivotNatsRec(natsToPivot.tail, couldNotPivot :+ natsToPivot.head)(applier) - - natsToPivot.headOption match { - case Some((n, nv, shift, pv)) => - def fromNamed(n: rct.Nat): NatPattern = { - n match { - case i: rct.NatIdentifier => - makePatVar(i.name, shift, natPatVars, NatPatternVar, Unknown) - case PosInf => NatPatternNode(NatPosInf) - case NegInf => NatPatternNode(NatNegInf) - case Cst(c) => NatPatternNode(NatCst(c)) - case Sum(Nil) => NatPatternNode(NatCst(0)) - case Sum(t +: ts) => ts.foldRight(fromNamed(t)) { case (t, acc) => - NatPatternNode(NatAdd(fromNamed(t), acc)) - } - case Prod(Nil) => NatPatternNode(NatCst(1)) - case Prod(t +: ts) => ts.foldRight(fromNamed(t)) { case (t, acc) => - NatPatternNode(NatMul(fromNamed(t), acc)) - } - case Pow(b, e) => - NatPatternNode(NatPow(fromNamed(b), fromNamed(e))) - case Mod(a, b) => - NatPatternNode(NatMod(fromNamed(a), fromNamed(b))) - case IntDiv(a, b) => - NatPatternNode(NatIntDiv(fromNamed(a), fromNamed(b))) - case _ => throw new Exception(s"no support for $n") - } - } - - val natsToFindOut = HashMap[rct.NatIdentifier, Integer]().withDefault(_ => 0) - ArithExpr.visit(n, { - case ni: rct.NatIdentifier => - val isKnown = natPatVars.get(ni.name) - .exists(shiftMap => shiftMap.exists { case (s, (pv, status)) => status != Unknown }) - if (!isKnown) { - natsToFindOut(ni) += 1 - } - case _ => - }) - natsToFindOut.size match { - case 0 => // check nv = n - val valuePat = fromNamed(n) - val updateShifts = shiftAppliers(natPatVars, natPatMkShift, natPatMkShiftCheck) - updateShifts(ComputeNatCheckApplier(pv, valuePat, pivotSuccess)) - case 1 => - val (potentialPivot, uses) = natsToFindOut.head - if (uses == 1) { - tryPivot(potentialPivot, n, nv) match { - case Some(value) => - val valuePat = fromNamed(value) - val updateShifts = shiftAppliers(natPatVars, natPatMkShift, natPatMkShiftCheck) - val pivotPat = makePatVar(potentialPivot.name, shift, natPatVars, - NatPatternVar, Known) - updateShifts(ComputeNatApplier(pivotPat, valuePat, - shiftAppliers(natPatVars, natPatMkShift, natPatMkShiftCheck)(pivotSuccess))) - case None => pivotFailure - } - } else { - pivotFailure - } - case _ => pivotFailure - } - case None => - if (couldNotPivot.nonEmpty) { - throw new Exception(s"could not pivot nats: $couldNotPivot") - } else { - applier - } - } - } - val searcher: Searcher = lhsPat.compile() val param = parameters.foldRight((a: Applier) => a) { case (c, acc) => c match { @@ -529,8 +603,8 @@ object NamedRewrite { val shiftDTPV = shiftAppliers(dataTypePatVars, dataTypePatMkShift, dataTypePatMkShiftCheck) val shiftTPV = shiftAppliers(typePatVars, typePatMkShift, typePatMkShiftCheck) val shiftAPV = shiftAppliers(addrPatVars, addrPatMkShift, addrPatMkShiftCheck) - val pivotNats = pivotNatsRec(natsToPivot.toSeq, Seq()) _ - val applier = param(shiftPV(shiftNPV(shiftDTPV(shiftTPV(shiftAPV(pivotNats(rhsPat))))))) + val pivotNPV = pivotNats(natsToPivot.toSeq, natPatVars, natPatMkShift, natPatMkShiftCheck, ComputeNatCheckApplier, ComputeNatApplier) + val applier = param(shiftPV(shiftNPV(shiftDTPV(shiftTPV(shiftAPV(pivotNPV(rhsPat))))))) def allIsShiftCoherent[S, V](pvm: PatternVarMap[S, V]): Boolean = pvm.forall { case (_, shiftMap) => diff --git a/src/main/scala/rise/eqsat/Reggvolution.scala b/src/main/scala/rise/eqsat/Reggvolution.scala index d8707b8b4..72e455e7a 100644 --- a/src/main/scala/rise/eqsat/Reggvolution.scala +++ b/src/main/scala/rise/eqsat/Reggvolution.scala @@ -27,6 +27,7 @@ object Reggvolution { reggvolve(Pattern.fromExpr(expr)) // same as NamedRewrite.init, but flattens DeBruijn indices from different kinds. + // TODO: could factorize even more def reggvolveNamedRewrite( name: String, rule: (NamedRewriteDSL.Pattern, NamedRewriteDSL.Pattern), @@ -44,344 +45,56 @@ object Reggvolution { type FlatShift = Int val patVars: PatternVarMap[FlatShift, PatternVar] = HashMap() + val natPatVars: PatternVarMap[FlatShift, NatPatternVar] = HashMap() + val dataTypePatVars: PatternVarMap[FlatShift, DataTypePatternVar] = HashMap() + val typePatVars: PatternVarMap[FlatShift, TypePatternVar] = HashMap() + val addrPatVars: PatternVarMap[FlatShift, AddressPatternVar] = HashMap() // nats which we need to pivot to avoid matching over certain nat constructs val natsToPivot = Vec[(rct.Nat, rct.NatIdentifier, FlatShift, NatPatternVar)]() val boundVarToShift = HashMap[String, FlatShift]() - def makeExprPatVar( - name: String, - shift: FlatShift, - status: PatVarStatus - ): PatternVar = - makePatVar(name, shift, patVars, PatternVar, status) - - def makeOtherPatVar[V]( - name: String, - shift: FlatShift, - constructor: Int => V, - status: PatVarStatus, - ): V = - makeExprPatVar(name, shift, status) match { - case PatternVar(index) => constructor(index) - } - def shiftOfBound(bound: Expr.Bound): FlatShift = - // (bound.expr.size, bound.nat.size, bound.data.size, bound.addr.size, bound.n2n.size) bound.expr.size + bound.nat.size + bound.data.size + bound.addr.size + bound.n2n.size - def makePat(expr: rc.Expr, - bound: Expr.Bound, - isRhs: Boolean, - matchType: Boolean = true): Pattern = - Pattern(expr match { - case i: rc.Identifier if freeV.contains(i.name) => - makeExprPatVar(i.name, shiftOfBound(bound), if (isRhs) { Unknown } else { Known }) - case i: rc.Identifier => PatternNode(Var(bound.indexOf(i))) - - // note: we do not match for the type of lambda bodies, as we can always infer it: - // lam(x : xt, e : et) : xt -> et - case rc.Lambda(x, e) => - // right now we assume that all bound variables are uniquely named - if (!isRhs) { - assert(!boundVarToShift.contains(x.name)) - boundVarToShift += x.name -> shiftOfBound(bound) - } - PatternNode(Lambda(makePat(e, bound + x, isRhs, matchType = false))) - case rc.DepLambda(rct.NatKind, x: rct.NatIdentifier, e) => - PatternNode(NatLambda(makePat(e, bound + x, isRhs, matchType = false))) - case rc.DepLambda(rct.DataKind, x: rcdt.DataTypeIdentifier, e) => - PatternNode(DataLambda(makePat(e, bound + x, isRhs, matchType = false))) - case rc.DepLambda(rct.AddressSpaceKind, x: rct.AddressSpaceIdentifier, e) => - PatternNode(AddrLambda(makePat(e, bound + x, isRhs, matchType = false))) - case rc.DepLambda(_, _, _) => ??? - - case rc.App(rc.App(NamedRewriteDSL.Composition(_), f), g) => - PatternNode(Composition( - makePat(f, bound, isRhs, matchType = true), - makePat(g, bound, isRhs, matchType = false))) - - // note: we do not match for the type of applied functions, as we can always infer it: - // app(f : et -> at, e : et) : at - case rc.App(f, e) => - PatternNode(App(makePat(f, bound, isRhs, matchType = false), makePat(e, bound, isRhs))) - case rc.DepApp(rct.NatKind, f, x: rct.Nat) => - PatternNode(NatApp( - makePat(f, bound, isRhs, matchType = false), makeNPat(x, bound, isRhs))) - case rc.DepApp(rct.DataKind, f, x: rct.DataType) => - PatternNode(DataApp( - makePat(f, bound, isRhs, matchType = false), makeDTPat(x, bound, isRhs))) - case rc.DepApp(rct.AddressSpaceKind, f, x: rct.AddressSpace) => - PatternNode(AddrApp( - makePat(f, bound, isRhs, matchType = false), makeAPat(x, bound, isRhs))) - case rc.DepApp(_, _, _) => ??? - - case rc.Literal(rc.semantics.NatData(n)) => - PatternNode(NatLiteral(makeNPat(n, bound, isRhs))) - case rc.Literal(rc.semantics.IndexData(i, n)) => - PatternNode(IndexLiteral(makeNPat(i, bound, isRhs), makeNPat(n, bound, isRhs))) - case rc.Literal(d) => PatternNode(Literal(d)) - // note: we set the primitive type to a place holder here, - // because we do not want type information at the node level - case p: rc.Primitive => PatternNode(Primitive(p.setType(rct.TypePlaceholder))) - - case _ => ??? - }, if (!isRhs && !matchType) TypePatternAny else makeTPat(expr.t, bound, isRhs)) - - def makeNPat(n: rct.Nat, bound: Expr.Bound, isRhs: Boolean): NatPattern = - n match { - case i: rct.NatIdentifier if freeT(rct.NatKind.IDWrapper(i)) => - makeOtherPatVar(i.name, shiftOfBound(bound), - NatPatternVar, if (isRhs) { Unknown } else { Known }) - case i: rct.NatIdentifier => - NatPatternNode(NatVar(bound.indexOf(i))) - case ae.Cst(c) => - NatPatternNode(NatCst(c)) - case ae.Sum(Nil) => NatPatternNode(NatCst(0)) - case ae.Sum(t +: ts) if isRhs => ts.foldRight(makeNPat(t, bound, isRhs)) { case (t, acc) => - NatPatternNode(NatAdd(makeNPat(t, bound, isRhs), acc)) - } - case ae.Prod(Nil) => NatPatternNode(NatCst(1)) - case ae.Prod(t +: ts) if isRhs => ts.foldRight(makeNPat(t, bound, isRhs)) { case (t, acc) => - NatPatternNode(NatMul(makeNPat(t, bound, isRhs), acc)) - } - case ae.Pow(b, e) if isRhs => - NatPatternNode(NatPow(makeNPat(b, bound, isRhs), makeNPat(e, bound, isRhs))) - // do not match over these nat constructs on the left-hand side, - // as structural matching would not be sufficient, - // try to pivot the equality around a fresh pattern variable instead - case ae.Sum(_) | ae.Prod(_) | ae.Pow(_, _) if !isRhs => - val nv = rct.NatIdentifier(s"_nv${natsToPivot.size}") - val shift = shiftOfBound(bound) - val pv = makeOtherPatVar(nv.name, shift, NatPatternVar, Known) - natsToPivot.addOne((n, nv, shift, pv)) - pv - case _ => - throw new Exception(s"did not expect $n") - } - - def makeDTPat(dt: rct.DataType, bound: Expr.Bound, isRhs: Boolean): DataTypePattern = - dt match { - case i: rcdt.DataTypeIdentifier if freeT(IDWrapper(i)) => - makeOtherPatVar(i.name, shiftOfBound(bound), - DataTypePatternVar, if (isRhs) { Unknown } else { Known }) - case i: rcdt.DataTypeIdentifier => - DataTypePatternNode(DataTypeVar(bound.indexOf(i))) - case s: rcdt.ScalarType => - DataTypePatternNode(ScalarType(s)) - case rcdt.NatType => - DataTypePatternNode(NatType) - case rcdt.VectorType(s, et) => - DataTypePatternNode(VectorType(makeNPat(s, bound, isRhs), makeDTPat(et, bound, isRhs))) - case rcdt.IndexType(s) => - DataTypePatternNode(IndexType(makeNPat(s, bound, isRhs))) - case rcdt.PairType(dt1, dt2) => - DataTypePatternNode(PairType(makeDTPat(dt1, bound, isRhs), makeDTPat(dt2, bound, isRhs))) - case rcdt.ArrayType(s, et) => - DataTypePatternNode(ArrayType(makeNPat(s, bound, isRhs), makeDTPat(et, bound, isRhs))) - case _: rcdt.DepArrayType | _: rcdt.DepPairType[_, _] | - _: rcdt.NatToDataApply | _: rcdt.FragmentType | rcdt.ManagedBufferType(_) | rcdt.OpaqueType(_) => - throw new Exception(s"did not expect $dt") - } - - def makeTPat(t: rct.ExprType, bound: Expr.Bound, isRhs: Boolean): TypePattern = - t match { - case dt: rct.DataType => makeDTPat(dt, bound, isRhs) - case rct.FunType(a, b) => - TypePatternNode(FunType(makeTPat(a, bound, isRhs), makeTPat(b, bound, isRhs))) - case rct.DepFunType(rct.NatKind, x: rct.NatIdentifier, t) => - TypePatternNode(NatFunType(makeTPat(t, bound + x, isRhs))) - case rct.DepFunType(rct.DataKind, x: rcdt.DataTypeIdentifier, t) => - TypePatternNode(DataFunType(makeTPat(t, bound + x, isRhs))) - case rct.DepFunType(rct.AddressSpaceKind, x: rct.AddressSpaceIdentifier, t) => - TypePatternNode(AddrFunType(makeTPat(t, bound + x, isRhs))) - case rct.DepFunType(_, _, _) => ??? - case i: rct.TypeIdentifier => - assert(freeT(rct.TypeKind.IDWrapper(i))) - makeOtherPatVar(i.name, shiftOfBound(bound), - TypePatternVar, if (isRhs) { Unknown } else { Known }) - case rct.TypePlaceholder => - throw new Exception(s"did not expect $t, something was not infered") - } - - def makeAPat(a: rct.AddressSpace, bound: Expr.Bound, isRhs: Boolean): AddressPattern = - a match { - case i: rct.AddressSpaceIdentifier if freeT(rct.AddressSpaceKind.IDWrapper(i)) => - makeOtherPatVar(i.name, shiftOfBound(bound), - AddressPatternVar, if (isRhs) { Unknown } else { Known }) - case i: rct.AddressSpaceIdentifier => - AddressPatternNode(AddressVar(bound.indexOf(i))) - case rct.AddressSpace.Global => AddressPatternNode(Global) - case rct.AddressSpace.Local => AddressPatternNode(Local) - case rct.AddressSpace.Private => AddressPatternNode(Private) - case rct.AddressSpace.Constant => AddressPatternNode(Constant) - } - - val lhsPat = makePat(typedLhs, Expr.Bound.empty, isRhs = false) - val rhsPat = makePat(typedRhs, Expr.Bound.empty, isRhs = true) - - def shiftAppliers[S, V](pvm: PatternVarMap[S, V], - mkShift: (S, V) => (S, V) => String => String, - mkShiftCheck: (S, V) => (S, V) => String => String, - ): String => String = { - pvm.foldRight { a: String => a } { case ((name, shiftMap), acc) => - shiftMap.collectFirst { case (s, (v, ShiftCoherent)) => (s, v) } - // if nothing is shift coherent yet, pick any known shift as our reference - .orElse(shiftMap.collectFirst { case (s, (v, Known)) => (s, v) }) match { - case Some(base) => - shiftMap(base._1) = (base._2, ShiftCoherent) - - shiftMap.foldRight(acc) { case ((shift, (pv, status)), acc) => - status match { - // nothing to do - case ShiftCoherent => acc - // check a shifted variable - case Known => - shiftMap(shift) = (pv, ShiftCoherent) - a: String => acc(mkShiftCheck.tupled(base)(shift, pv)(a)) - // construct a shifted variable - case Unknown => - shiftMap(shift) = (pv, ShiftCoherent) - a: String => acc(mkShift.tupled(base)(shift, pv)(a)) - } - } - // nothing is known, but it may become known later (e.g. after nat pivoting) - case None => acc - } - } - } - - def patMkShift(s1: FlatShift, pv1: PatternVar) - (s2: FlatShift, pv2: PatternVar) + val lhsPat = makePat(typedLhs, Expr.Bound.empty, isRhs = false, + freeV, freeT, + shiftOfBound, shiftOfBound, shiftOfBound, shiftOfBound, + patVars, natPatVars, dataTypePatVars, typePatVars, addrPatVars, + natsToPivot, boundVarToShift) + val rhsPat = makePat(typedRhs, Expr.Bound.empty, isRhs = true, + freeV, freeT, + shiftOfBound, shiftOfBound, shiftOfBound, shiftOfBound, + patVars, natPatVars, dataTypePatVars, typePatVars, addrPatVars, + natsToPivot, boundVarToShift) + + def patMkShift(s1: FlatShift, pv1: Any) + (s2: FlatShift, pv2: Any) (applier: String): String = { assert(s1 != s2) val cutoff = s1 val shift = s2 - s1 - s"""{ shifted("${pv1}", "${pv2}", ${shift}, ${cutoff}, ${applier}) }""" + s"""shifted("${pv1}", "${pv2}", ${shift}, ${cutoff}, ${applier})""" } - def patMkShiftCheck(s1: FlatShift, pv1: PatternVar) - (s2: FlatShift, pv2: PatternVar) + def patMkShiftCheck(s1: FlatShift, pv1: Any) + (s2: FlatShift, pv2: Any) (applier: String): String = { assert(s1 != s2) val cutoff = s1 val shift = s2 - s1 - s"""{ shifted_check("${pv1}", "${pv2}", ${shift}, ${cutoff}, ${applier}) }""" + s"""shifted_check("${pv1}", "${pv2}", ${shift}, ${cutoff}, ${applier})""" } - // FIXME: duplicated from type inference's 'pivotSolution' - @scala.annotation.tailrec - def tryPivot(pivot: rct.NatIdentifier, n: rct.Nat, value: rct.Nat): Option[rct.Nat] = { - import arithexpr.arithmetic._ - - n match { - case i: rct.NatIdentifier if i == pivot => Some(value) - case Prod(terms) => - val (p, rest) = terms.partition(t => ArithExpr.contains(t, pivot)) - if (p.size != 1) { - None - } else { - tryPivot(pivot, p.head, rest.foldLeft(value)({ - case (v, r) => v /^ r - })) - } - case Sum(terms) => - val (p, rest) = terms.partition(t => ArithExpr.contains(t, pivot)) - if (p.size != 1) { - None - } else { - tryPivot(pivot, p.head, rest.foldLeft(value)({ - case (v, r) => v - r - })) - } - case Pow(b, Cst(-1)) => tryPivot(pivot, b, Cst(1) /^ value) - case Mod(p, m) if p == pivot => - val k = rct.NatIdentifier(s"_k_${p}_${m}", RangeAdd(0, PosInf, 1)) - Some(k*m + value) - case _ => None - } + def mkComputeNatCheck(pv: NatPatternVar, valuePat: NatPattern, applier: String): String = { + val vp = reggvolve(valuePat) + s"""compute_nat_check("${pv}", "${vp}", ${applier})""" } - def pivotNatsRec(natsToPivot: Seq[(rct.Nat, rct.NatIdentifier, FlatShift, NatPatternVar)], - couldNotPivot: Seq[(rct.Nat, rct.NatIdentifier, FlatShift, NatPatternVar)]) - (applier: String): String = { - import arithexpr.arithmetic._ - - def pivotSuccess = pivotNatsRec(natsToPivot.tail ++ couldNotPivot, Seq())(applier) - def pivotFailure = pivotNatsRec(natsToPivot.tail, couldNotPivot :+ natsToPivot.head)(applier) - - natsToPivot.headOption match { - case Some((n, nv, shift, pv)) => - def fromNamed(n: rct.Nat): NatPattern = { - n match { - case i: rct.NatIdentifier => - makeOtherPatVar(i.name, shift, NatPatternVar, Unknown) - case PosInf => NatPatternNode(NatPosInf) - case NegInf => NatPatternNode(NatNegInf) - case Cst(c) => NatPatternNode(NatCst(c)) - case Sum(Nil) => NatPatternNode(NatCst(0)) - case Sum(t +: ts) => ts.foldRight(fromNamed(t)) { case (t, acc) => - NatPatternNode(NatAdd(fromNamed(t), acc)) - } - case Prod(Nil) => NatPatternNode(NatCst(1)) - case Prod(t +: ts) => ts.foldRight(fromNamed(t)) { case (t, acc) => - NatPatternNode(NatMul(fromNamed(t), acc)) - } - case Pow(b, e) => - NatPatternNode(NatPow(fromNamed(b), fromNamed(e))) - case Mod(a, b) => - NatPatternNode(NatMod(fromNamed(a), fromNamed(b))) - case IntDiv(a, b) => - NatPatternNode(NatIntDiv(fromNamed(a), fromNamed(b))) - case _ => throw new Exception(s"no support for $n") - } - } - - val natsToFindOut = HashMap[rct.NatIdentifier, Integer]().withDefault(_ => 0) - ArithExpr.visit(n, { - case ni: rct.NatIdentifier => - val isKnown = patVars.get(ni.name) - .exists(shiftMap => shiftMap.exists { case (s, (pv, status)) => status != Unknown }) - if (!isKnown) { - natsToFindOut(ni) += 1 - } - case _ => - }) - natsToFindOut.size match { - case 0 => // check nv = n - val valuePat = fromNamed(n) - val updateShifts = shiftAppliers(patVars, patMkShift, patMkShiftCheck) - val vp = reggvolve(valuePat) - updateShifts(s"""{ compute_nat_check("${pv}", "${vp}", ${pivotSuccess}) }""") - // ComputeNatCheckApplier(pv, valuePat, pivotSuccess)) - case 1 => - val (potentialPivot, uses) = natsToFindOut.head - if (uses == 1) { - tryPivot(potentialPivot, n, nv) match { - case Some(value) => - val valuePat = fromNamed(value) - val updateShifts = shiftAppliers(patVars, patMkShift, patMkShiftCheck) - val pivotPat = makeOtherPatVar(potentialPivot.name, shift, - NatPatternVar, Known) - val vp = reggvolve(valuePat) - val a = shiftAppliers(patVars, patMkShift, patMkShiftCheck)(pivotSuccess) - updateShifts(s"""{ compute_nat("${pivotPat}", "${vp}", ${a}) }""") - // ComputeNatApplier(pivotPat, valuePat, a) - case None => pivotFailure - } - } else { - pivotFailure - } - case _ => pivotFailure - } - case None => - if (couldNotPivot.nonEmpty) { - throw new Exception(s"could not pivot nats: $couldNotPivot") - } else { - applier - } - } + def mkComputeNat(pv: NatPatternVar, valuePat: NatPattern, applier: String): String = { + val vp = reggvolve(valuePat) + s"""compute_nat("${pv}", "${vp}", ${applier})""" } val searcher: String = s""""${reggvolve(lhsPat)}"""" @@ -399,32 +112,37 @@ object Reggvolution { case (s, (pv, Known)) => (s, pv) }.get val nfIndex = iS - nfShift // >= 0 because iS >= nfShift - (a: String) => s"""{ not_free_in("${iPV}", ${nfIndex}, ${a}) }""" + (a: String) => s"""not_free_in("${iPV}", ${nfIndex}, ${a})""" // NotFreeInApplier(iPV, nfIndex, acc(a)) case VectorizeScalarFun(f, n, fV) => - val (enPV, nST) = patVars(n)(0) - val nPV = enPV match { - case PatternVar(index) => NatPatternVar(index) - } + val (nPV, nST) = natPatVars(n)(0) assert(nST == Known) val (fPV, fST) = patVars(f)(0) assert(fST == Known) val fVPV = makePatVar(fV, 0, patVars, PatternVar, Known) - (a: String) => s"""{ vectorize_scalar_fun("${fPV}", ${nPV}, "${fVPV}", ${a}) }""" + (a: String) => s"""vectorize_scalar_fun("${fPV}", "${nPV}", "${fVPV}", ${a})""" // VectorizeScalarFunExtractApplier(fPV, nPV, fVPV, acc(a)) } } + val rhsPatApplier = s"""pat("${reggvolve(rhsPat)}")""" val shiftPV = shiftAppliers(patVars, patMkShift, patMkShiftCheck) - val pivotNats = pivotNatsRec(natsToPivot.toSeq, Seq()) _ - val rhsPatApplier = s""""${reggvolve(rhsPat)}"""" - val applier = param(shiftPV(pivotNats(rhsPatApplier))) + val shiftNPV = shiftAppliers(natPatVars, patMkShift, patMkShiftCheck) + val shiftDTPV = shiftAppliers(dataTypePatVars, patMkShift, patMkShiftCheck) + val shiftTPV = shiftAppliers(typePatVars, patMkShift, patMkShiftCheck) + val shiftAPV = shiftAppliers(addrPatVars, patMkShift, patMkShiftCheck) + val pivotNPV = pivotNats(natsToPivot.toSeq, natPatVars, patMkShift, patMkShiftCheck, mkComputeNatCheck, mkComputeNat) + val applier = param(shiftPV(shiftNPV(shiftDTPV(shiftTPV(shiftAPV(pivotNPV(rhsPatApplier))))))) def allIsShiftCoherent[S, V](pvm: PatternVarMap[S, V]): Boolean = pvm.forall { case (_, shiftMap) => shiftMap.forall { case (_, (_, status)) => status == ShiftCoherent }} assert(allIsShiftCoherent(patVars)) + assert(allIsShiftCoherent(natPatVars)) + assert(allIsShiftCoherent(dataTypePatVars)) + assert(allIsShiftCoherent(typePatVars)) + assert(allIsShiftCoherent(addrPatVars)) - s"""rewrite!("${name}"; ${searcher} => ${applier})""" + s"""rewrite!("${name}"; ${searcher} => { ${applier}) }""" } // DEPRECATED: From 6666649210e879dbf5be45ff05e48cd27785d9c6 Mon Sep 17 00:00:00 2001 From: Thomas Koehler Date: Thu, 20 Nov 2025 15:12:52 +0100 Subject: [PATCH 7/7] print more reggvolution --- src/main/scala/benchmarks/eqsat/mm.scala | 7 +++++-- src/main/scala/rise/eqsat/GuidedSearch.scala | 4 ++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/main/scala/benchmarks/eqsat/mm.scala b/src/main/scala/benchmarks/eqsat/mm.scala index e17f11250..604936b20 100644 --- a/src/main/scala/benchmarks/eqsat/mm.scala +++ b/src/main/scala/benchmarks/eqsat/mm.scala @@ -587,7 +587,7 @@ object mm { containsMap(k, containsMap(cst(1)`.`vecT(cst(32), f32), ?))))))) - private def parallel_SRCL(): GuidedSearch.Result = { + def parallel_SRCL(): GuidedSearch.Result = { val start = mm // val start = apps.tvmGemm.arrayPacking(mm).get @@ -650,8 +650,12 @@ object mm { ) val rs = fs.map { case (n, f) => System.gc() // hint garbage collection to get more precise memory usage statistics + println(s"---- running $n search") (n, util.time(f())) } + + throw new Exception("Reggvolution done") + rs.foreach { case (n, (_, r)) => r.exprs.headOption.foreach(codegen(n, _)) } @@ -834,6 +838,5 @@ object Reggvolve { )) println("----") - throw new Exception("done") } } diff --git a/src/main/scala/rise/eqsat/GuidedSearch.scala b/src/main/scala/rise/eqsat/GuidedSearch.scala index 3e2b29953..a1ed5c6a5 100644 --- a/src/main/scala/rise/eqsat/GuidedSearch.scala +++ b/src/main/scala/rise/eqsat/GuidedSearch.scala @@ -148,6 +148,10 @@ class GuidedSearch( (0L, Seq()) } + for (e <- newBeam) { + println(Reggvolution.reggvolve(e)) + } + val totalIterationsTime = runner.iterations.iterator.map(_.totalTime).sum stats += GuidedSearch.Stats( initializeTime = initializeTime,