Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 178 additions & 1 deletion src/main/scala/benchmarks/eqsat/mm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,7 @@ object mm {
containsMap(k,
containsMap(cst(1)`.`vecT(cst(32), f32), ?)))))))

private def parallel_SRCL(): GuidedSearch.Result = {
def parallel_SRCL(): GuidedSearch.Result = {
val start = mm
// val start = apps.tvmGemm.arrayPacking(mm).get

Expand Down Expand Up @@ -621,6 +621,7 @@ object mm {
}

def main(args: Array[String]): Unit = {
Reggvolve.init_rules()
// val names = Set(args(0))
// fs.filter { case (k, _) => names(k) }

Expand Down Expand Up @@ -649,8 +650,12 @@ object mm {
)
val rs = fs.map { case (n, f) =>
System.gc() // hint garbage collection to get more precise memory usage statistics
println(s"---- running $n search")
(n, util.time(f()))
}

throw new Exception("Reggvolution done")

rs.foreach { case (n, (_, r)) =>
r.exprs.headOption.foreach(codegen(n, _))
}
Expand All @@ -663,3 +668,175 @@ object mm {
}
}
}

object Reggvolve {
def init_rules(): Unit = {
import rise.core.{primitives => rcp}
import rise.core.types.{Nat, DataType, AddressSpace}
import NamedRewriteDSL._

println(Reggvolution.reggvolve(Expr.fromNamed(benchmarks.eqsat.mm.mm)))

println("---- RULES ----")

println(Reggvolution.reggvolveNamedRewrite("map-fission",
app(map, lam("x", app("f", "gx" :: ("dt": DataType))))
-->
lam("in", app(app(map, "f"), app(app(map, lam("x", "gx")), "in"))),
Seq("f" notFree "x")
))
println(Reggvolution.reggvolveNamedRewrite("reduce-seq",
reduce --> rcp.reduceSeq.primitive
))
println(Reggvolution.reggvolveNamedRewrite("eliminate-map-identity",
app(map, lam("x", "x"))
-->
lam("y", "y")
))
println(Reggvolution.reggvolveNamedRewrite("reduce-seq-map-fusion",
app(app(app(rcp.reduceSeq.primitive, "f"), "init"), app(app(map, "g"), "in"))
-->
app(app(app(rcp.reduceSeq.primitive, lam("acc", lam("x",
app(app("f", "acc"), app("g", "x"))))), "init"), "in")
))
def splitJoin(n: Int) = Reggvolution.reggvolveNamedRewrite(s"split-join-$n",
app(app(map, "f"), "in")
-->
app(join, app(app(map, app(map, "f")), app(nApp(split, n), "in")))
)
println(splitJoin(32))
def splitJoin2M(n: Int) = Reggvolution.reggvolveNamedRewrite(s"split-join-2m-$n",
app(app(map, app(map, app(map, "f"))), "in")
-->
app(app(map, app(map, join)), app(app(map, app(map, app(map, app(map, "f")))), app(app(map, app(map, nApp(split, n))), "in")))
)
println(splitJoin2M(32))
def blockedReduce(n: Int) = Reggvolution.reggvolveNamedRewrite(s"blocked-reduce-$n",
app(app(app(reduce, "op" :: ("a" ->: "a" ->: t("a"))), "init"), "arg")
-->
app(app(app(rcp.reduceSeq.primitive,
lam("acc", lam("y", app(app("op", "acc"),
app(app(app(reduce, "op"), "init"), "y"))))),
"init"), app(nApp(split, n), "arg"))
)
println(blockedReduce(4))
println(Reggvolution.reggvolveNamedRewrite("split-before-map",
app(nApp(split, "n"), app(app(map, "f"), "in"))
-->
app(app(map, app(map, "f")), app(nApp(split, "n"), "in"))
))
println(Reggvolution.reggvolveNamedRewrite("reduce-seq-map-fission",
app(app(rcp.reduceSeq.primitive, lam("acc", lam("y",
app(app("op", "acc"), "gy" :: ("dt": DataType))))), "init")
-->
lam("in", app(app(app(rcp.reduceSeq.primitive, "op"), "init"),
app(app(map, lam("y", "gy")), "in"))),
Seq("op" notFree "y")
))
println(Reggvolution.reggvolveNamedRewrite("lift-reduce-seq",
app(map, app(app(rcp.reduceSeq.primitive, "op"), "init"))
-->
lam("in",
app(app(app(rcp.reduceSeq.primitive, lam("acc", lam("y",
app(app(map, lam("z", app(app("op", app(fst, "z")), app(snd, "z")))),
app(app(zip, "acc"), "y"))
))), app(rcp.generate.primitive, lam("i", "init"))),
app(transpose, "in")))
))
println(Reggvolution.reggvolveNamedRewrite("lift-reduce-seq-2",
app(map, lam("x", app(app(add, app(fst, "x")),
app(app(app(rcp.reduceSeq.primitive, "op"), lf32(0)), app(snd, "x")))))
-->
lam("in", app(lam("uz",
app(app(app(rcp.reduceSeq.primitive, lam("acc", lam("y",
app(app(map, lam("z", app(app("op", app(fst, "z")), app(snd, "z")))),
app(app(zip, "acc"), "y"))
))), app(fst, "uz")), app(transpose, app(snd, "uz")))),
app(unzip, "in")))
))
println(Reggvolution.reggvolveNamedRewrite("lift-reduce-seq-3",
(app(map, lam("x",
app(app(app(rcp.reduceSeq.primitive, "op"), app(fst, app(unzip, "x"))),
app(transpose, app(snd, app(unzip, "x")))))))
-->
lam("in",
app(app(app(rcp.reduceSeq.primitive, lam("acc", lam("y",
app(app(map, lam("z", app(app("op", app(fst, "z")), app(snd, "z")))),
app(app(zip, "acc"), "y"))
))), app(fst, app(unzip, app(app(map, unzip), "in")))),
app(transpose, app(app(map, transpose), app(snd, app(unzip, app(app(map, unzip), "in"))))))),
Seq("op" notFree "x")
))
println(Reggvolution.reggvolveNamedRewrite("transpose-around-map-map-f-1m",
app(app(map, app(map, app(map, "f"))), "in")
-->
app(app(map, transpose), app(app(map, app(map, app(map, "f"))), app(app(map, transpose), "in")))
))
println(Reggvolution.reggvolveNamedRewrite("store-to-mem",
("in" :: ("dt": DataType))
-->
app(app(rcp.let.primitive, app(rcp.toMem.primitive, "in")), lam("x", "x"))
))
def splitJoin2(n: Int) = Reggvolution.reggvolveNamedRewrite(s"split-join-2-$n",
("in" :: (`?n``.``?dt`))
-->
app(join, app(nApp(split, n), "in"))
)
splitJoin2(32)
println(Reggvolution.reggvolveNamedRewrite("map-array",
("x" :: (`?n``.``?dt`))
-->
app(app(map, lam("y", "y")), "x")
))
println(Reggvolution.reggvolveNamedRewrite("map-fusion",
app(app(map, "f"), app(app(map, "g"), "in"))
-->
app(app(map, lam("x", app("f", app("g", "x")))), "in")
))
println(Reggvolution.reggvolveNamedRewrite("map-eta-abstraction",
app(map, "f") --> lam("x", app(app(map, "f"), "x"))
))
object vectorize {
def after(n: Int, dt: DataType) = Reggvolution.reggvolveNamedRewrite(s"vec-$n-after-$dt",
// TODO: if m % n == 0 ?
("e" :: (("m": Nat)`.`dt))
-->
app(asScalar, app(nApp(asVector, n), "e"))
)
}
println(vectorize.after(32, f32))
println(vectorize.after(32, f32 x f32))
println(vectorize.after(32, f32 x (f32 x f32)))
println(Reggvolution.reggvolveNamedRewrite("vec-before-map-f32",
app(nApp(asVector, "n"), app(app(map, "f" :: f32 ->: f32), ("in": Pattern)))
-->
app(app(map, "fV"), app(nApp(asVector, "n"), "in")),
Seq(vectorizeScalarFun("f", "n", "fV"))
))
println(Reggvolution.reggvolveNamedRewrite("vec-before-map-f32xf32",
app(nApp(asVector, "n"), app(app(map, "f" :: (f32 x f32) ->: f32), ("in": Pattern)))
-->
app(app(map, "fV"),
app(app(zip, app(nApp(asVector, "n"), app(fst, app(unzip, "in")))),
app(nApp(asVector, "n"), app(snd, app(unzip, "in"))))),
Seq(vectorizeScalarFun("f", "n", "fV"))
))
println(Reggvolution.reggvolveNamedRewrite("vec-before-map-f32x-f32xf32",
app(nApp(asVector, "n"), app(app(map, "f" :: (f32 x (f32 x f32)) ->: f32), ("in": Pattern)))
-->
app(app(map, "fV"),
app(app(zip, app(nApp(asVector, "n"), app(fst, app(unzip, "in")))),
app(app(zip, app(nApp(asVector, "n"), app(fst, app(unzip, app(snd, app(unzip, "in")))))),
app(nApp(asVector, "n"), app(snd, app(unzip, app(snd, app(unzip, "in")))))))),
Seq(vectorizeScalarFun("f", "n", "fV"))
))
println(Reggvolution.reggvolveNamedRewrite("reduce-seq-unroll",
rcp.reduceSeq.primitive --> rcp.reduceSeqUnroll.primitive
))
println(Reggvolution.reggvolveNamedRewrite("map-par",
map --> rise.openMP.primitives.mapPar.primitive
))

println("----")
}
}
4 changes: 4 additions & 0 deletions src/main/scala/rise/eqsat/GuidedSearch.scala
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,10 @@ class GuidedSearch(
(0L, Seq())
}

for (e <- newBeam) {
println(Reggvolution.reggvolve(e))
}

val totalIterationsTime = runner.iterations.iterator.map(_.totalTime).sum
stats += GuidedSearch.Stats(
initializeTime = initializeTime,
Expand Down
Loading