Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions hail/src/main/scala/is/hail/expr/ir/AggOp.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,6 @@ case class AggStatePhysicalSignature(m: Map[AggOp, PhysicalAggSignature], defaul
def lookup(op: AggOp): PhysicalAggSignature = m(op)
}

// Companion object supplying an explicit four-argument factory for PhysicalAggSignature.
// NOTE(review): the right-hand side calls `PhysicalAggSignature(op, physicalInitOpArgs,
// physicalSeqOpArgs, nested)` with the exact argument list of this `apply`; unless the
// case-class constructor (declared below, not fully visible here) takes a different
// parameter list, this call resolves back to this very method and would recurse forever
// (or be rejected as a duplicate definition) — confirm against the case-class declaration.
object PhysicalAggSignature {
def apply(op: AggOp,
physicalInitOpArgs: Seq[PType],
physicalSeqOpArgs: Seq[PType],
nested: Option[Seq[AggStatePhysicalSignature]]
): PhysicalAggSignature = PhysicalAggSignature(op, physicalInitOpArgs, physicalSeqOpArgs, nested)
}

object AggStatePhysicalSignature {
  /** Lifts a single [[PhysicalAggSignature]] into a state container holding
    * exactly that signature, keyed and defaulted by its op. */
  def apply(sig: PhysicalAggSignature): AggStatePhysicalSignature = {
    val singleEntry = Map(sig.op -> sig)
    AggStatePhysicalSignature(singleEntry, sig.op)
  }
}
Expand All @@ -70,6 +62,8 @@ case class PhysicalAggSignature(
def seqOpArgs: Seq[Type] = physicalSeqOpArgs.map(_.virtualType)

// Virtual-type view of this physical signature: drops all physical layout
// information by mapping each arg PType to its virtual Type.
lazy val virtual: AggSignature = AggSignature(op, physicalInitOpArgs.map(_.virtualType), physicalSeqOpArgs.map(_.virtualType))
// Wraps this signature as a one-entry AggStatePhysicalSignature (no nested states).
lazy val singletonContainer: AggStatePhysicalSignature = AggStatePhysicalSignature(Map(op -> this), op, None)

}

sealed trait AggOp {}
Expand Down
2 changes: 2 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/BaseIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ abstract class BaseIR {

// Structural copy of the whole tree; the result shares no nodes with `this`.
def deepCopy(): this.type = copy(newChildren = children.map(_.deepCopy())).asInstanceOf[this.type]

// An equivalent tree guaranteed to contain no shared (aliased) subtrees:
// returns `this` when HasIRSharing finds none, otherwise a deep copy. Cached
// (lazy val) because the sharing check walks the full tree.
// NOTE(review): presumably exists so in-place annotation passes (e.g. InferPType,
// which calls `ir.noSharing`) never write through an aliased node — confirm.
lazy val noSharing: this.type = if (HasIRSharing(this)) this.deepCopy() else this

def mapChildren(f: (BaseIR) => BaseIR): BaseIR = {
val newChildren = children.map(f)
if ((children, newChildren).zipped.forall(_ eq _))
Expand Down
17 changes: 8 additions & 9 deletions hail/src/main/scala/is/hail/expr/ir/Compile.scala
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ object CompileWithAggregators2 {

private def apply[F >: Null : TypeInfo, R: TypeInfo : ClassTag](
ctx: ExecuteContext,
aggSigs: Array[AggStateSignature],
aggSigs: Array[AggStatePhysicalSignature],
args: Seq[(String, PType, ClassTag[_])],
argTypeInfo: Array[MaybeGenericTypeInfo[_]],
body: IR,
Expand All @@ -233,8 +233,7 @@ object CompileWithAggregators2 {
val normalizeNames = new NormalizeNames(_.toString)
val normalizedBody = normalizeNames(body,
Env(args.map { case (n, _, _) => n -> n }: _*))
val pAggSigs = aggSigs.map(_.toCanonicalPhysical)
val k = CodeCacheKey(pAggSigs.toFastIndexedSeq, args.map { case (n, pt, _) => (n, pt) }, normalizedBody)
val k = CodeCacheKey(aggSigs, args.map { case (n, pt, _) => (n, pt) }, normalizedBody)
codeCache.get(k) match {
case Some(v) =>
return (v.typ, v.f.asInstanceOf[(Int, Region) => (F with FunctionWithAggRegion)])
Expand All @@ -251,11 +250,11 @@ object CompileWithAggregators2 {

TypeCheck(ir, BindingEnv(Env.fromSeq[Type](args.map { case (name, t, _) => name -> t.virtualType })))

InferPType(if(HasIRSharing(ir)) ir.deepCopy() else ir, Env(args.map { case (n, pt, _) => n -> pt}: _*))
InferPType(ir.noSharing, Env(args.map { case (n, pt, _) => n -> pt}: _*), aggSigs, null, null)

assert(TypeToIRIntermediateClassTag(ir.typ) == classTag[R])

Emit(ctx, ir, fb, Some(pAggSigs))
Emit(ctx, ir, fb, Some(aggSigs))

val f = fb.resultWithIndex()
codeCache += k -> CodeCacheValue(ir.pType, f)
Expand All @@ -264,7 +263,7 @@ object CompileWithAggregators2 {

def apply[F >: Null : TypeInfo, R: TypeInfo : ClassTag](
ctx: ExecuteContext,
aggSigs: Array[AggStateSignature],
aggSigs: Array[AggStatePhysicalSignature],
args: Seq[(String, PType, ClassTag[_])],
body: IR,
optimize: Boolean
Expand All @@ -285,15 +284,15 @@ object CompileWithAggregators2 {

def apply[R: TypeInfo : ClassTag](
ctx: ExecuteContext,
aggSigs: Array[AggStateSignature],
aggSigs: Array[AggStatePhysicalSignature],
body: IR): (PType, (Int, Region) => AsmFunction1[Region, R] with FunctionWithAggRegion) = {

apply[AsmFunction1[Region, R], R](ctx, aggSigs, FastSeq[(String, PType, ClassTag[_])](), body, optimize = true)
}

def apply[T0: ClassTag, R: TypeInfo : ClassTag](
ctx: ExecuteContext,
aggSigs: Array[AggStateSignature],
aggSigs: Array[AggStatePhysicalSignature],
name0: String, typ0: PType,
body: IR): (PType, (Int, Region) => AsmFunction3[Region, T0, Boolean, R] with FunctionWithAggRegion) = {

Expand All @@ -302,7 +301,7 @@ object CompileWithAggregators2 {

def apply[T0: ClassTag, T1: ClassTag, R: TypeInfo : ClassTag](
ctx: ExecuteContext,
aggSigs: Array[AggStateSignature],
aggSigs: Array[AggStatePhysicalSignature],
name0: String, typ0: PType,
name1: String, typ1: PType,
body: IR): (PType, (Int, Region) => (AsmFunction5[Region, T0, Boolean, T1, Boolean, R] with FunctionWithAggRegion)) = {
Expand Down
1 change: 1 addition & 0 deletions hail/src/main/scala/is/hail/expr/ir/IR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ final case class ArrayAggScan(a: IR, name: String, query: IR) extends IR
// Mixin for IR nodes that carry per-state aggregator signatures and cache the
// corresponding physical (PType-level) signatures.
trait InferredPhysicalAggSignature {
// will be filled in by InferPType in subsequent PR
def signature: IndexedSeq[AggStateSignature]
// Mutable slot written externally by InferPType during inference; null until then.
var physicalSignatures2: Array[AggStatePhysicalSignature] = _
// Eager canonical-physical view of `signature`.
// NOTE(review): evaluated at trait initialization, so `signature` must already be
// bound at that point (a constructor val, as in RunAgg below); a subclass providing
// it as a def over not-yet-initialized state would NPE here — confirm all mixers.
val physicalSignatures: Array[AggStatePhysicalSignature] = signature.map(_.toCanonicalPhysical).toArray
}
// IR node pairing an aggregation `body` (which populates the aggregator states in
// `signature`) with a `result` expression computed from those states.
final case class RunAgg(body: IR, result: IR, signature: IndexedSeq[AggStateSignature]) extends IR with InferredPhysicalAggSignature
Expand Down
189 changes: 176 additions & 13 deletions hail/src/main/scala/is/hail/expr/ir/InferPType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,66 @@ import is.hail.expr.types.virtual.{TNDArray, TVoid}
import is.hail.utils._

object InferPType {

// Recursively resets the cached physical-type annotation (`_pType2`) on `x` and
// every node beneath it, so that inference can be re-run from scratch.
def clearPTypes(x: IR): Unit = {
  x._pType2 = null
  for (child <- x.children)
    clearPTypes(child.asInstanceOf[IR])
}

// does not unify physical arg types if multiple nested seq/init ops appear; instead takes the first. The emitter checks equality.
// Derives the physical (PType-level) signature of one aggregator state from its
// virtual signature plus the InitOp/SeqOp occurrences collected while inferring
// the IR, recursing into nested states for Group / array-element aggregators.
def computePhysicalAgg(virt: AggStateSignature, initsAB: ArrayBuilder[RecursiveArrayBuilderElement[InitOp]],
seqAB: ArrayBuilder[RecursiveArrayBuilderElement[SeqOp]]): AggStatePhysicalSignature = {
val inits = initsAB.result()
val seqs = seqAB.result()
// Each state must be initialized and sequenced at least once for arg types to be known.
assert(inits.nonEmpty)
assert(seqs.nonEmpty)
virt.default match {
case AggElementsLengthCheck() =>
// Array-element aggregation: the AggElementsLengthCheck op carries the init and
// length-check arguments; the AggElements op carries the per-element arguments
// and the nested states.

val iHead = inits.find(_.value.op == AggElementsLengthCheck()).get
val iNested = iHead.nested.get
val iHeadArgTypes = iHead.value.args.map(_.pType2)

val sLCHead = seqs.find(_.value.op == AggElementsLengthCheck()).get
val sLCArgTypes = sLCHead.value.args.map(_.pType2)
val sAEHead = seqs.find(_.value.op == AggElements()).get
val sNested = sAEHead.nested.get
val sHeadArgTypes = sAEHead.value.args.map(_.pType2)

val vNested = virt.nested.get.toArray

// Recurse pairwise: nested virtual signature i with the i-th collected init/seq builders.
val nested = vNested.indices.map { i => computePhysicalAgg(vNested(i), iNested(i), sNested(i)) }
AggStatePhysicalSignature(Map(
AggElementsLengthCheck() -> PhysicalAggSignature(AggElementsLengthCheck(), iHeadArgTypes, sLCArgTypes),
AggElements() -> PhysicalAggSignature(AggElements(), FastIndexedSeq(), sHeadArgTypes)
), AggElementsLengthCheck(), Some(nested))

case Group() =>
// Grouped aggregation: arg types come from the first init/seq occurrence only
// (per the note above, later occurrences are not unified here).
val iHead = inits.head
val iNested = iHead.nested.get
val iHeadArgTypes = iHead.value.args.map(_.pType2)

val sHead = seqs.head
val sNested = sHead.nested.get
val sHeadArgTypes = sHead.value.args.map(_.pType2)

val vNested = virt.nested.get.toArray

val nested = vNested.indices.map { i => computePhysicalAgg(vNested(i), iNested(i), sNested(i)) }
val psig = PhysicalAggSignature(Group(), iHeadArgTypes, sHeadArgTypes)
AggStatePhysicalSignature(Map(Group() -> psig), Group(), Some(nested))

case _ =>
// Leaf aggregator: nested states are not allowed; take arg types from the first ops.
assert(inits.forall(_.nested.isEmpty))
assert(seqs.forall(_.nested.isEmpty))
val iHead = inits.head.value
val iHeadArgTypes = iHead.args.map(_.pType2)
val sHead = seqs.head.value
val sHeadArgTypes = sHead.args.map(_.pType2)
virt.defaultSignature.toPhysical(iHeadArgTypes, sHeadArgTypes).singletonContainer
}
}

def getNestedElementPTypes(ptypes: Seq[PType]): PType = {
assert(ptypes.forall(_.virtualType.isOfType(ptypes.head.virtualType)))
getNestedElementPTypesOfSameType(ptypes: Seq[PType])
Expand Down Expand Up @@ -40,10 +100,20 @@ object InferPType {
}
}

def apply(ir: IR, env: Env[PType]): Unit = {
assert(ir._pType2 == null)
def apply(ir: IR, env: Env[PType]): Unit = apply(ir, env, null, null, null)

// Per-state collection of op occurrences: index i holds the builder of
// InitOp/SeqOp elements observed for aggregator state i.
private type AAB[T] = Array[ArrayBuilder[RecursiveArrayBuilderElement[T]]]

// One observed op (`value`) plus, when the op has nested aggregator states
// (Group / array-element), the builders collected for those nested states.
case class RecursiveArrayBuilderElement[T](value: T, nested: Option[AAB[T]])

// Allocates one empty builder per aggregator state (`n` states total).
def newBuilder[T](n: Int): AAB[T] = Array.fill(n)(new ArrayBuilder[RecursiveArrayBuilderElement[T]])

def infer(ir: IR, env: Env[PType] = env): Unit = apply(ir, env)
def apply(ir: IR, env: Env[PType], aggs: Array[AggStatePhysicalSignature], inits: AAB[InitOp], seqs: AAB[SeqOp]): Unit = {
if (ir._pType2 != null)
throw new RuntimeException(ir.toString)

def infer(ir: IR, env: Env[PType] = env, aggs: Array[AggStatePhysicalSignature] = aggs,
inits: AAB[InitOp] = inits, seqs: AAB[SeqOp] = seqs): Unit = apply(ir, env, aggs, inits, seqs)

ir._pType2 = ir match {
case I32(_) => PInt32(true)
Expand Down Expand Up @@ -206,6 +276,11 @@ object InferPType {
assert(body.pType2 isOfType zero.pType2)

zero.pType2.setRequired(body.pType2.required)
case ArrayFor(a, value, body) =>
infer(a)

infer(body, env.bind(value -> a.pType2.asInstanceOf[PStream].elementType))
PVoid
case ArrayFold2(a, acc, valueName, seq, res) =>
infer(a)
acc.foreach { case (_, accIR) => infer(accIR) }
Expand All @@ -226,7 +301,7 @@ object InferPType {
case ArrayLeftJoinDistinct(lIR, rIR, lName, rName, compare, join) =>
infer(lIR)
infer(rIR)
val e = env.bind(lName -> lIR.pType2.asInstanceOf[PStream].elementType, rName -> rIR.pType2.asInstanceOf[PStream].elementType)
val e = env.bind(lName -> lIR.pType2.asInstanceOf[PStream].elementType, rName -> rIR.pType2.asInstanceOf[PStream].elementType.setRequired(false))

infer(compare, e)
infer(join, e)
Expand Down Expand Up @@ -364,13 +439,6 @@ object InferPType {
theIR._pType2
}))
case In(_, pType: PType) => pType
case ArrayFor(a, valueName, body) =>
infer(a)
infer(body, env.bind(valueName -> a._pType2.asInstanceOf[PStream].elementType))
PVoid
case x if x.typ == TVoid =>
x.children.foreach(c => infer(c.asInstanceOf[IR]))
PVoid
case CollectDistributedArray(contextsIR, globalsIR, contextsName, globalsName, bodyIR) =>
infer(contextsIR)
infer(globalsIR)
Expand Down Expand Up @@ -402,10 +470,105 @@ object InferPType {
val allReq = rPTypes.forall(f => f.typ.required)
PCanonicalTuple(rPTypes, allReq)

case _: AggLet | _: RunAgg | _: RunAggScan | _: NDArrayAgg | _: AggFilter | _: AggExplode |
_: AggGroupBy | _: AggArrayPerElement | _: ApplyAggOp | _: ApplyScanOp | _: AggStateValue => PType.canonical(ir.typ)
case x@InitOp(i, args, sig, op) =>
op match {
case Group() =>
val nested = sig.nested.get
val newInits = newBuilder[InitOp](nested.length)
val IndexedSeq(initArg) = args
infer(initArg, env, null, inits = newInits, seqs = null)
if (inits != null)
inits(i) += RecursiveArrayBuilderElement(x, Some(newInits))
case AggElementsLengthCheck() =>
val nested = sig.nested.get
val newInits = newBuilder[InitOp](nested.length)
val initArg = args match {
case Seq(len, initArg) =>
infer(len, env, null, null, null)
initArg
case Seq(initArg) => initArg
}
infer(initArg, env, null, inits = newInits, seqs = null)
if (inits != null)
inits(i) += RecursiveArrayBuilderElement(x, Some(newInits))
case _ =>
assert(sig.nested.isEmpty)
args.foreach(infer(_, env, null, null, null))
if (inits != null)
inits(i) += RecursiveArrayBuilderElement(x, None)
}
PVoid

case x@SeqOp(i, args, sig, op) =>
op match {
case Group() =>
val nested = sig.nested.get
val newSeqs = newBuilder[SeqOp](nested.length)
val IndexedSeq(k, seqArg) = args
infer(k, env, null, inits = null, seqs = null)
infer(seqArg, env, null, inits = null, seqs = newSeqs)
if (seqs != null)
seqs(i) += RecursiveArrayBuilderElement(x, Some(newSeqs))
case AggElements() =>
val nested = sig.nested.get
val newSeqs = newBuilder[SeqOp](nested.length)
val IndexedSeq(idx, seqArg) = args
infer(idx, env, null, inits = null, seqs = null)
infer(seqArg, env, null, inits = null, seqs = newSeqs)
if (seqs != null)
seqs(i) += RecursiveArrayBuilderElement(x, Some(newSeqs))
case AggElementsLengthCheck() =>
val nested = sig.nested.get
val IndexedSeq(idx) = args
infer(idx, env, null, inits = null, seqs = null)
if (seqs != null)
seqs(i) += RecursiveArrayBuilderElement(x, None)
case _ =>
assert(sig.nested.isEmpty)
args.foreach(infer(_, env, null, null, null))
if (seqs != null)
seqs(i) += RecursiveArrayBuilderElement(x, None)
}
PVoid

case x@ResultOp(resultIdx, sigs) =>
PCanonicalTuple(true, (resultIdx until resultIdx + sigs.length).map(i => aggs(i).resultType): _*)

case x@RunAgg(body, result, signature) =>
val inits = newBuilder[InitOp](signature.length)
val seqs = newBuilder[SeqOp](signature.length)
infer(body, env, inits = inits, seqs = seqs, aggs = null)
val sigs = signature.indices.map { i => computePhysicalAgg(signature(i), inits(i), seqs(i)) }.toArray
infer(result, env, aggs = sigs, inits = null, seqs = null)
x.physicalSignatures2 = sigs
result.pType2

case x@RunAggScan(array, name, init, seq, result, signature) =>
infer(array)
val e2 = env.bind(name, coerce[PStreamable](array.pType2).elementType)
val inits = newBuilder[InitOp](signature.length)
val seqs = newBuilder[SeqOp](signature.length)
infer(init, env = e2, inits = inits, seqs = null, aggs = null)
infer(seq, env = e2, inits = null, seqs = seqs, aggs = null)
val sigs = signature.indices.map { i => computePhysicalAgg(signature(i), inits(i), seqs(i)) }.toArray
infer(result, env = e2, aggs = sigs, inits = null, seqs = null)
x.physicalSignatures2 = sigs
PCanonicalArray(result.pType2, array._pType2.required)

case AggStateValue(i, sig) => PCanonicalBinary(true)
case x if x.typ == TVoid =>
x.children.foreach(c => infer(c.asInstanceOf[IR]))
PVoid

case NDArrayAgg(nd, _) =>
infer(nd)
PType.canonical(ir.typ)
case x if x.typ == TVoid =>
x.children.foreach(c => infer(c.asInstanceOf[IR]))
PVoid
}


// Allow only requiredeness to diverge
assert(ir.pType2.virtualType isOfType ir.typ)
}
Expand Down
18 changes: 12 additions & 6 deletions hail/src/main/scala/is/hail/expr/ir/Interpret.scala
Original file line number Diff line number Diff line change
Expand Up @@ -613,23 +613,29 @@ object Interpret {
} else {
val spec = BufferSpec.defaultUncompressed

val physicalAggs = extracted.getPhysicalAggs(
ctx,
Env("global" -> value.globals.t),
Env("global" -> value.globals.t, "row" -> value.rvd.rowPType)
)

val (_, initOp) = CompileWithAggregators2[Long, Unit](ctx,
extracted.aggs,
physicalAggs,
"global", value.globals.t,
extracted.init)

val (_, partitionOpSeq) = CompileWithAggregators2[Long, Long, Unit](ctx,
extracted.aggs,
physicalAggs,
"global", value.globals.t,
"row", value.rvd.rowPType,
extracted.seqPerElt)

val read = extracted.deserialize(ctx, spec)
val write = extracted.serialize(ctx, spec)
val combOpF = extracted.combOpF(ctx, spec)
val read = extracted.deserialize(ctx, spec, physicalAggs)
val write = extracted.serialize(ctx, spec, physicalAggs)
val combOpF = extracted.combOpF(ctx, spec, physicalAggs)

val (rTyp: PTuple, f) = CompileWithAggregators2[Long, Long](ctx,
extracted.aggs,
physicalAggs,
"global", value.globals.t,
Let(res, extracted.results, MakeTuple.ordered(FastSeq(extracted.postAggIR))))
assert(rTyp.types(0).virtualType == query.typ)
Expand Down
Loading