diff --git a/hail/src/main/scala/is/hail/expr/ir/AggOp.scala b/hail/src/main/scala/is/hail/expr/ir/AggOp.scala index 6437476f6a6..f18d0f69894 100644 --- a/hail/src/main/scala/is/hail/expr/ir/AggOp.scala +++ b/hail/src/main/scala/is/hail/expr/ir/AggOp.scala @@ -49,14 +49,6 @@ case class AggStatePhysicalSignature(m: Map[AggOp, PhysicalAggSignature], defaul def lookup(op: AggOp): PhysicalAggSignature = m(op) } -object PhysicalAggSignature { - def apply(op: AggOp, - physicalInitOpArgs: Seq[PType], - physicalSeqOpArgs: Seq[PType], - nested: Option[Seq[AggStatePhysicalSignature]] - ): PhysicalAggSignature = PhysicalAggSignature(op, physicalInitOpArgs, physicalSeqOpArgs, nested) -} - object AggStatePhysicalSignature { def apply(sig: PhysicalAggSignature): AggStatePhysicalSignature = AggStatePhysicalSignature(Map(sig.op -> sig), sig.op) } @@ -70,6 +62,8 @@ case class PhysicalAggSignature( def seqOpArgs: Seq[Type] = physicalSeqOpArgs.map(_.virtualType) lazy val virtual: AggSignature = AggSignature(op, physicalInitOpArgs.map(_.virtualType), physicalSeqOpArgs.map(_.virtualType)) + lazy val singletonContainer: AggStatePhysicalSignature = AggStatePhysicalSignature(Map(op -> this), op, None) + } sealed trait AggOp {} diff --git a/hail/src/main/scala/is/hail/expr/ir/BaseIR.scala b/hail/src/main/scala/is/hail/expr/ir/BaseIR.scala index 6afc59fbd62..89f9a1eaa74 100644 --- a/hail/src/main/scala/is/hail/expr/ir/BaseIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/BaseIR.scala @@ -11,6 +11,8 @@ abstract class BaseIR { def deepCopy(): this.type = copy(newChildren = children.map(_.deepCopy())).asInstanceOf[this.type] + lazy val noSharing: this.type = if (HasIRSharing(this)) this.deepCopy() else this + def mapChildren(f: (BaseIR) => BaseIR): BaseIR = { val newChildren = children.map(f) if ((children, newChildren).zipped.forall(_ eq _)) diff --git a/hail/src/main/scala/is/hail/expr/ir/Compile.scala b/hail/src/main/scala/is/hail/expr/ir/Compile.scala index c75a90de932..192e769b93f 100644 
--- a/hail/src/main/scala/is/hail/expr/ir/Compile.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Compile.scala @@ -224,7 +224,7 @@ object CompileWithAggregators2 { private def apply[F >: Null : TypeInfo, R: TypeInfo : ClassTag]( ctx: ExecuteContext, - aggSigs: Array[AggStateSignature], + aggSigs: Array[AggStatePhysicalSignature], args: Seq[(String, PType, ClassTag[_])], argTypeInfo: Array[MaybeGenericTypeInfo[_]], body: IR, @@ -233,8 +233,7 @@ object CompileWithAggregators2 { val normalizeNames = new NormalizeNames(_.toString) val normalizedBody = normalizeNames(body, Env(args.map { case (n, _, _) => n -> n }: _*)) - val pAggSigs = aggSigs.map(_.toCanonicalPhysical) - val k = CodeCacheKey(pAggSigs.toFastIndexedSeq, args.map { case (n, pt, _) => (n, pt) }, normalizedBody) + val k = CodeCacheKey(aggSigs, args.map { case (n, pt, _) => (n, pt) }, normalizedBody) codeCache.get(k) match { case Some(v) => return (v.typ, v.f.asInstanceOf[(Int, Region) => (F with FunctionWithAggRegion)]) @@ -251,11 +250,11 @@ object CompileWithAggregators2 { TypeCheck(ir, BindingEnv(Env.fromSeq[Type](args.map { case (name, t, _) => name -> t.virtualType }))) - InferPType(if(HasIRSharing(ir)) ir.deepCopy() else ir, Env(args.map { case (n, pt, _) => n -> pt}: _*)) + InferPType(ir.noSharing, Env(args.map { case (n, pt, _) => n -> pt}: _*), aggSigs, null, null) assert(TypeToIRIntermediateClassTag(ir.typ) == classTag[R]) - Emit(ctx, ir, fb, Some(pAggSigs)) + Emit(ctx, ir, fb, Some(aggSigs)) val f = fb.resultWithIndex() codeCache += k -> CodeCacheValue(ir.pType, f) @@ -264,7 +263,7 @@ object CompileWithAggregators2 { def apply[F >: Null : TypeInfo, R: TypeInfo : ClassTag]( ctx: ExecuteContext, - aggSigs: Array[AggStateSignature], + aggSigs: Array[AggStatePhysicalSignature], args: Seq[(String, PType, ClassTag[_])], body: IR, optimize: Boolean @@ -285,7 +284,7 @@ object CompileWithAggregators2 { def apply[R: TypeInfo : ClassTag]( ctx: ExecuteContext, - aggSigs: Array[AggStateSignature], + aggSigs: 
Array[AggStatePhysicalSignature], body: IR): (PType, (Int, Region) => AsmFunction1[Region, R] with FunctionWithAggRegion) = { apply[AsmFunction1[Region, R], R](ctx, aggSigs, FastSeq[(String, PType, ClassTag[_])](), body, optimize = true) @@ -293,7 +292,7 @@ object CompileWithAggregators2 { def apply[T0: ClassTag, R: TypeInfo : ClassTag]( ctx: ExecuteContext, - aggSigs: Array[AggStateSignature], + aggSigs: Array[AggStatePhysicalSignature], name0: String, typ0: PType, body: IR): (PType, (Int, Region) => AsmFunction3[Region, T0, Boolean, R] with FunctionWithAggRegion) = { @@ -302,7 +301,7 @@ object CompileWithAggregators2 { def apply[T0: ClassTag, T1: ClassTag, R: TypeInfo : ClassTag]( ctx: ExecuteContext, - aggSigs: Array[AggStateSignature], + aggSigs: Array[AggStatePhysicalSignature], name0: String, typ0: PType, name1: String, typ1: PType, body: IR): (PType, (Int, Region) => (AsmFunction5[Region, T0, Boolean, T1, Boolean, R] with FunctionWithAggRegion)) = { diff --git a/hail/src/main/scala/is/hail/expr/ir/IR.scala b/hail/src/main/scala/is/hail/expr/ir/IR.scala index 1f736f5d6a6..c0e89d7a437 100644 --- a/hail/src/main/scala/is/hail/expr/ir/IR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/IR.scala @@ -243,6 +243,7 @@ final case class ArrayAggScan(a: IR, name: String, query: IR) extends IR trait InferredPhysicalAggSignature { // will be filled in by InferPType in subsequent PR def signature: IndexedSeq[AggStateSignature] + var physicalSignatures2: Array[AggStatePhysicalSignature] = _ val physicalSignatures: Array[AggStatePhysicalSignature] = signature.map(_.toCanonicalPhysical).toArray } final case class RunAgg(body: IR, result: IR, signature: IndexedSeq[AggStateSignature]) extends IR with InferredPhysicalAggSignature diff --git a/hail/src/main/scala/is/hail/expr/ir/InferPType.scala b/hail/src/main/scala/is/hail/expr/ir/InferPType.scala index 3682e75dda3..1ea7aa705c4 100644 --- a/hail/src/main/scala/is/hail/expr/ir/InferPType.scala +++ 
b/hail/src/main/scala/is/hail/expr/ir/InferPType.scala @@ -5,6 +5,66 @@ import is.hail.expr.types.virtual.{TNDArray, TVoid} import is.hail.utils._ object InferPType { + + def clearPTypes(x: IR): Unit = { + x._pType2 = null + x.children.foreach { c => clearPTypes(c.asInstanceOf[IR]) } + } + + // does not unify physical arg types if multiple nested seq/init ops appear; instead takes the first. The emitter checks equality. + def computePhysicalAgg(virt: AggStateSignature, initsAB: ArrayBuilder[RecursiveArrayBuilderElement[InitOp]], + seqAB: ArrayBuilder[RecursiveArrayBuilderElement[SeqOp]]): AggStatePhysicalSignature = { + val inits = initsAB.result() + val seqs = seqAB.result() + assert(inits.nonEmpty) + assert(seqs.nonEmpty) + virt.default match { + case AggElementsLengthCheck() => + + val iHead = inits.find(_.value.op == AggElementsLengthCheck()).get + val iNested = iHead.nested.get + val iHeadArgTypes = iHead.value.args.map(_.pType2) + + val sLCHead = seqs.find(_.value.op == AggElementsLengthCheck()).get + val sLCArgTypes = sLCHead.value.args.map(_.pType2) + val sAEHead = seqs.find(_.value.op == AggElements()).get + val sNested = sAEHead.nested.get + val sHeadArgTypes = sAEHead.value.args.map(_.pType2) + + val vNested = virt.nested.get.toArray + + val nested = vNested.indices.map { i => computePhysicalAgg(vNested(i), iNested(i), sNested(i)) } + AggStatePhysicalSignature(Map( + AggElementsLengthCheck() -> PhysicalAggSignature(AggElementsLengthCheck(), iHeadArgTypes, sLCArgTypes), + AggElements() -> PhysicalAggSignature(AggElements(), FastIndexedSeq(), sHeadArgTypes) + ), AggElementsLengthCheck(), Some(nested)) + + case Group() => + val iHead = inits.head + val iNested = iHead.nested.get + val iHeadArgTypes = iHead.value.args.map(_.pType2) + + val sHead = seqs.head + val sNested = sHead.nested.get + val sHeadArgTypes = sHead.value.args.map(_.pType2) + + val vNested = virt.nested.get.toArray + + val nested = vNested.indices.map { i => computePhysicalAgg(vNested(i), 
iNested(i), sNested(i)) } + val psig = PhysicalAggSignature(Group(), iHeadArgTypes, sHeadArgTypes) + AggStatePhysicalSignature(Map(Group() -> psig), Group(), Some(nested)) + + case _ => + assert(inits.forall(_.nested.isEmpty)) + assert(seqs.forall(_.nested.isEmpty)) + val iHead = inits.head.value + val iHeadArgTypes = iHead.args.map(_.pType2) + val sHead = seqs.head.value + val sHeadArgTypes = sHead.args.map(_.pType2) + virt.defaultSignature.toPhysical(iHeadArgTypes, sHeadArgTypes).singletonContainer + } + } + def getNestedElementPTypes(ptypes: Seq[PType]): PType = { assert(ptypes.forall(_.virtualType.isOfType(ptypes.head.virtualType))) getNestedElementPTypesOfSameType(ptypes: Seq[PType]) @@ -40,10 +100,20 @@ object InferPType { } } - def apply(ir: IR, env: Env[PType]): Unit = { - assert(ir._pType2 == null) + def apply(ir: IR, env: Env[PType]): Unit = apply(ir, env, null, null, null) + + private type AAB[T] = Array[ArrayBuilder[RecursiveArrayBuilderElement[T]]] + + case class RecursiveArrayBuilderElement[T](value: T, nested: Option[AAB[T]]) + + def newBuilder[T](n: Int): AAB[T] = Array.fill(n)(new ArrayBuilder[RecursiveArrayBuilderElement[T]]) - def infer(ir: IR, env: Env[PType] = env): Unit = apply(ir, env) + def apply(ir: IR, env: Env[PType], aggs: Array[AggStatePhysicalSignature], inits: AAB[InitOp], seqs: AAB[SeqOp]): Unit = { + if (ir._pType2 != null) + throw new RuntimeException(ir.toString) + + def infer(ir: IR, env: Env[PType] = env, aggs: Array[AggStatePhysicalSignature] = aggs, + inits: AAB[InitOp] = inits, seqs: AAB[SeqOp] = seqs): Unit = apply(ir, env, aggs, inits, seqs) ir._pType2 = ir match { case I32(_) => PInt32(true) @@ -206,6 +276,11 @@ object InferPType { assert(body.pType2 isOfType zero.pType2) zero.pType2.setRequired(body.pType2.required) + case ArrayFor(a, value, body) => + infer(a) + + infer(body, env.bind(value -> a.pType2.asInstanceOf[PStream].elementType)) + PVoid case ArrayFold2(a, acc, valueName, seq, res) => infer(a) acc.foreach { case 
(_, accIR) => infer(accIR) } @@ -226,7 +301,7 @@ object InferPType { case ArrayLeftJoinDistinct(lIR, rIR, lName, rName, compare, join) => infer(lIR) infer(rIR) - val e = env.bind(lName -> lIR.pType2.asInstanceOf[PStream].elementType, rName -> rIR.pType2.asInstanceOf[PStream].elementType) + val e = env.bind(lName -> lIR.pType2.asInstanceOf[PStream].elementType, rName -> rIR.pType2.asInstanceOf[PStream].elementType.setRequired(false)) infer(compare, e) infer(join, e) @@ -364,13 +439,6 @@ object InferPType { theIR._pType2 })) case In(_, pType: PType) => pType - case ArrayFor(a, valueName, body) => - infer(a) - infer(body, env.bind(valueName -> a._pType2.asInstanceOf[PStream].elementType)) - PVoid - case x if x.typ == TVoid => - x.children.foreach(c => infer(c.asInstanceOf[IR])) - PVoid case CollectDistributedArray(contextsIR, globalsIR, contextsName, globalsName, bodyIR) => infer(contextsIR) infer(globalsIR) @@ -402,10 +470,105 @@ object InferPType { val allReq = rPTypes.forall(f => f.typ.required) PCanonicalTuple(rPTypes, allReq) - case _: AggLet | _: RunAgg | _: RunAggScan | _: NDArrayAgg | _: AggFilter | _: AggExplode | - _: AggGroupBy | _: AggArrayPerElement | _: ApplyAggOp | _: ApplyScanOp | _: AggStateValue => PType.canonical(ir.typ) + case x@InitOp(i, args, sig, op) => + op match { + case Group() => + val nested = sig.nested.get + val newInits = newBuilder[InitOp](nested.length) + val IndexedSeq(initArg) = args + infer(initArg, env, null, inits = newInits, seqs = null) + if (inits != null) + inits(i) += RecursiveArrayBuilderElement(x, Some(newInits)) + case AggElementsLengthCheck() => + val nested = sig.nested.get + val newInits = newBuilder[InitOp](nested.length) + val initArg = args match { + case Seq(len, initArg) => + infer(len, env, null, null, null) + initArg + case Seq(initArg) => initArg + } + infer(initArg, env, null, inits = newInits, seqs = null) + if (inits != null) + inits(i) += RecursiveArrayBuilderElement(x, Some(newInits)) + case _ => + 
assert(sig.nested.isEmpty) + args.foreach(infer(_, env, null, null, null)) + if (inits != null) + inits(i) += RecursiveArrayBuilderElement(x, None) + } + PVoid + + case x@SeqOp(i, args, sig, op) => + op match { + case Group() => + val nested = sig.nested.get + val newSeqs = newBuilder[SeqOp](nested.length) + val IndexedSeq(k, seqArg) = args + infer(k, env, null, inits = null, seqs = null) + infer(seqArg, env, null, inits = null, seqs = newSeqs) + if (seqs != null) + seqs(i) += RecursiveArrayBuilderElement(x, Some(newSeqs)) + case AggElements() => + val nested = sig.nested.get + val newSeqs = newBuilder[SeqOp](nested.length) + val IndexedSeq(idx, seqArg) = args + infer(idx, env, null, inits = null, seqs = null) + infer(seqArg, env, null, inits = null, seqs = newSeqs) + if (seqs != null) + seqs(i) += RecursiveArrayBuilderElement(x, Some(newSeqs)) + case AggElementsLengthCheck() => + val nested = sig.nested.get + val IndexedSeq(idx) = args + infer(idx, env, null, inits = null, seqs = null) + if (seqs != null) + seqs(i) += RecursiveArrayBuilderElement(x, None) + case _ => + assert(sig.nested.isEmpty) + args.foreach(infer(_, env, null, null, null)) + if (seqs != null) + seqs(i) += RecursiveArrayBuilderElement(x, None) + } + PVoid + + case x@ResultOp(resultIdx, sigs) => + PCanonicalTuple(true, (resultIdx until resultIdx + sigs.length).map(i => aggs(i).resultType): _*) + + case x@RunAgg(body, result, signature) => + val inits = newBuilder[InitOp](signature.length) + val seqs = newBuilder[SeqOp](signature.length) + infer(body, env, inits = inits, seqs = seqs, aggs = null) + val sigs = signature.indices.map { i => computePhysicalAgg(signature(i), inits(i), seqs(i)) }.toArray + infer(result, env, aggs = sigs, inits = null, seqs = null) + x.physicalSignatures2 = sigs + result.pType2 + + case x@RunAggScan(array, name, init, seq, result, signature) => + infer(array) + val e2 = env.bind(name, coerce[PStreamable](array.pType2).elementType) + val inits = 
newBuilder[InitOp](signature.length) + val seqs = newBuilder[SeqOp](signature.length) + infer(init, env = e2, inits = inits, seqs = null, aggs = null) + infer(seq, env = e2, inits = null, seqs = seqs, aggs = null) + val sigs = signature.indices.map { i => computePhysicalAgg(signature(i), inits(i), seqs(i)) }.toArray + infer(result, env = e2, aggs = sigs, inits = null, seqs = null) + x.physicalSignatures2 = sigs + PCanonicalArray(result.pType2, array._pType2.required) + + case AggStateValue(i, sig) => PCanonicalBinary(true) + case x if x.typ == TVoid => + x.children.foreach(c => infer(c.asInstanceOf[IR])) + PVoid + + case NDArrayAgg(nd, _) => + infer(nd) + PType.canonical(ir.typ) + case x if x.typ == TVoid => + x.children.foreach(c => infer(c.asInstanceOf[IR])) + PVoid } + // Allow only requiredness to diverge assert(ir.pType2.virtualType isOfType ir.typ) } diff --git a/hail/src/main/scala/is/hail/expr/ir/Interpret.scala b/hail/src/main/scala/is/hail/expr/ir/Interpret.scala index ff374ef57fe..4a8a3fd5143 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Interpret.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Interpret.scala @@ -613,23 +613,29 @@ object Interpret { } else { val spec = BufferSpec.defaultUncompressed + val physicalAggs = extracted.getPhysicalAggs( + ctx, + Env("global" -> value.globals.t), + Env("global" -> value.globals.t, "row" -> value.rvd.rowPType) + ) + val (_, initOp) = CompileWithAggregators2[Long, Unit](ctx, - extracted.aggs, + physicalAggs, "global", value.globals.t, extracted.init) val (_, partitionOpSeq) = CompileWithAggregators2[Long, Long, Unit](ctx, - extracted.aggs, + physicalAggs, "global", value.globals.t, "row", value.rvd.rowPType, extracted.seqPerElt) - val read = extracted.deserialize(ctx, spec) - val write = extracted.serialize(ctx, spec) - val combOpF = extracted.combOpF(ctx, spec) + val read = extracted.deserialize(ctx, spec, physicalAggs) + val write = extracted.serialize(ctx, spec, physicalAggs) + val combOpF = 
extracted.combOpF(ctx, spec, physicalAggs) val (rTyp: PTuple, f) = CompileWithAggregators2[Long, Long](ctx, - extracted.aggs, + physicalAggs, "global", value.globals.t, Let(res, extracted.results, MakeTuple.ordered(FastSeq(extracted.postAggIR)))) assert(rTyp.types(0).virtualType == query.typ) diff --git a/hail/src/main/scala/is/hail/expr/ir/TableIR.scala b/hail/src/main/scala/is/hail/expr/ir/TableIR.scala index 6485ddb206f..eb3abdc9a55 100644 --- a/hail/src/main/scala/is/hail/expr/ir/TableIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/TableIR.scala @@ -1021,6 +1021,12 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR { rvd = tv.rvd.mapPartitionsWithIndex(RVDType(rTyp.asInstanceOf[PStruct], typ.key), itF)) } + val physicalAggs = extracted.getPhysicalAggs( + ctx, + Env("global" -> tv.globals.t), + Env("global" -> tv.globals.t, "row" -> tv.rvd.rowPType) + ) + val scanInitNeedsGlobals = Mentions(extracted.init, "global") val scanSeqNeedsGlobals = Mentions(extracted.seqPerElt, "global") val rowIterationNeedsGlobals = Mentions(extracted.postAggIR, "global") @@ -1040,22 +1046,22 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR { // 4. load in partStarts, calculate newRow based on those results. 
val (_, initF) = ir.CompileWithAggregators2[Long, Unit](ctx, - extracted.aggs, + physicalAggs, "global", tv.globals.t, Begin(FastIndexedSeq(extracted.init, extracted.serializeSet(0, 0, spec)))) val (_, eltSeqF) = ir.CompileWithAggregators2[Long, Long, Unit](ctx, - extracted.aggs, + physicalAggs, "global", Option(globalsBc).map(_.value.t).getOrElse(PStruct()), "row", tv.rvd.rowPType, extracted.eltOp(ctx)) - val read = extracted.deserialize(ctx, spec) - val write = extracted.serialize(ctx, spec) - val combOpF = extracted.combOpF(ctx, spec) + val read = extracted.deserialize(ctx, spec, physicalAggs) + val write = extracted.serialize(ctx, spec, physicalAggs) + val combOpF = extracted.combOpF(ctx, spec, physicalAggs) val (rTyp, f) = ir.CompileWithAggregators2[Long, Long, Long](ctx, - extracted.aggs, + physicalAggs, "global", Option(globalsBc).map(_.value.t).getOrElse(PStruct()), "row", tv.rvd.rowPType, Let(scanRef, extracted.results, extracted.postAggIR)) @@ -1387,26 +1393,32 @@ case class TableKeyByAndAggregate( val res = genUID() val extracted = agg.Extract(expr, res) + val physicalAggs = extracted.getPhysicalAggs( + ctx, + Env("global" -> prev.globals.t), + Env("global" -> prev.globals.t, "row" -> prev.rvd.rowPType) + ) + val (_, makeInit) = ir.CompileWithAggregators2[Long, Unit](ctx, - extracted.aggs, + physicalAggs, "global", prev.globals.t, extracted.init) val (_, makeSeq) = ir.CompileWithAggregators2[Long, Long, Unit](ctx, - extracted.aggs, + physicalAggs, "global", prev.globals.t, "row", prev.rvd.rowPType, extracted.seqPerElt) val (rTyp: PStruct, makeAnnotate) = ir.CompileWithAggregators2[Long, Long](ctx, - extracted.aggs, + physicalAggs, "global", prev.globals.t, Let(res, extracted.results, extracted.postAggIR)) assert(rTyp.virtualType == typ.valueType, s"$rTyp, ${ typ.valueType }") - val serialize = extracted.serialize(ctx, spec) - val deserialize = extracted.deserialize(ctx, spec) - val combOp = extracted.combOpF(ctx, spec) + val serialize = 
extracted.serialize(ctx, spec, physicalAggs) + val deserialize = extracted.deserialize(ctx, spec, physicalAggs) + val combOp = extracted.combOpF(ctx, spec, physicalAggs) val initF = makeInit(0, ctx.r) val globalsOffset = prev.globals.value.offset @@ -1517,13 +1529,19 @@ case class TableAggregateByKey(child: TableIR, expr: IR) extends TableIR { val res = genUID() val extracted = agg.Extract(expr, res) + val physicalAggs = extracted.getPhysicalAggs( + ctx, + Env("global" -> prev.globals.t), + Env("global" -> prev.globals.t, "row" -> prev.rvd.rowPType) + ) + val (_, makeInit) = ir.CompileWithAggregators2[Long, Unit](ctx, - extracted.aggs, + physicalAggs, "global", prev.globals.t, extracted.init) val (_, makeSeq) = ir.CompileWithAggregators2[Long, Long, Unit](ctx, - extracted.aggs, + physicalAggs, "global", prev.globals.t, "row", prev.rvd.rowPType, extracted.seqPerElt) @@ -1533,7 +1551,7 @@ case class TableAggregateByKey(child: TableIR, expr: IR) extends TableIR { val key = Ref(genUID(), keyType.virtualType) val value = Ref(genUID(), valueIR.typ) val (rowType: PStruct, makeRow) = ir.CompileWithAggregators2[Long, Long, Long](ctx, - extracted.aggs, + physicalAggs, "global", prev.globals.t, key.name, keyType, Let(value.name, valueIR, diff --git a/hail/src/main/scala/is/hail/expr/ir/agg/DownsampleAggregator.scala b/hail/src/main/scala/is/hail/expr/ir/agg/DownsampleAggregator.scala index 4fd41980191..94247216ebc 100644 --- a/hail/src/main/scala/is/hail/expr/ir/agg/DownsampleAggregator.scala +++ b/hail/src/main/scala/is/hail/expr/ir/agg/DownsampleAggregator.scala @@ -51,7 +51,6 @@ class DownsampleState(val fb: EmitFunctionBuilder[_], labelType: PArray, maxBuff def createState: Code[Unit] = region.isNull.mux(r := Region.stagedCreate(regionSize), Code._empty) - val binType = PStruct(required = true, "x" -> PInt32Required, "y" -> PInt32Required) val pointType = PStruct(required = true, "x" -> PFloat64Required, "y" -> PFloat64Required, "label" -> labelType) diff --git 
a/hail/src/main/scala/is/hail/expr/ir/agg/Extract.scala b/hail/src/main/scala/is/hail/expr/ir/agg/Extract.scala index 999060b5b03..6ace16a238b 100644 --- a/hail/src/main/scala/is/hail/expr/ir/agg/Extract.scala +++ b/hail/src/main/scala/is/hail/expr/ir/agg/Extract.scala @@ -4,6 +4,7 @@ import is.hail.HailContext import is.hail.annotations.{Region, RegionValue} import is.hail.expr.ir import is.hail.expr.ir._ +import is.hail.expr.ir.lowering.LoweringPipeline import is.hail.expr.types.physical._ import is.hail.expr.types.virtual._ import is.hail.io.BufferSpec @@ -43,9 +44,9 @@ case class Aggs(postAggIR: IR, init: IR, seqPerElt: IR, aggs: Array[AggStateSign def eltOp(ctx: ExecuteContext): IR = seqPerElt - def deserialize(ctx: ExecuteContext, spec: BufferSpec): ((Region, Array[Byte]) => Long) = { + def deserialize(ctx: ExecuteContext, spec: BufferSpec, physicalAggs: Array[AggStatePhysicalSignature]): ((Region, Array[Byte]) => Long) = { val (_, f) = ir.CompileWithAggregators2[Unit](ctx, - aggs, ir.DeserializeAggs(0, 0, spec, aggs)) + physicalAggs, ir.DeserializeAggs(0, 0, spec, aggs)) { (aggRegion: Region, bytes: Array[Byte]) => val f2 = f(0, aggRegion); @@ -56,9 +57,9 @@ case class Aggs(postAggIR: IR, init: IR, seqPerElt: IR, aggs: Array[AggStateSign } } - def serialize(ctx: ExecuteContext, spec: BufferSpec): (Region, Long) => Array[Byte] = { + def serialize(ctx: ExecuteContext, spec: BufferSpec, physicalAggs: Array[AggStatePhysicalSignature]): (Region, Long) => Array[Byte] = { val (_, f) = ir.CompileWithAggregators2[Unit](ctx, - aggs, ir.SerializeAggs(0, 0, spec, aggs)) + physicalAggs, ir.SerializeAggs(0, 0, spec, aggs)) { (aggRegion: Region, off: Long) => val f2 = f(0, aggRegion); @@ -68,9 +69,9 @@ case class Aggs(postAggIR: IR, init: IR, seqPerElt: IR, aggs: Array[AggStateSign } } - def combOpF(ctx: ExecuteContext, spec: BufferSpec): (Array[Byte], Array[Byte]) => Array[Byte] = { + def combOpF(ctx: ExecuteContext, spec: BufferSpec, physicalAggs: 
Array[AggStatePhysicalSignature]): (Array[Byte], Array[Byte]) => Array[Byte] = { val (_, f) = ir.CompileWithAggregators2[Unit](ctx, - aggs ++ aggs, + physicalAggs ++ physicalAggs, Begin( deserializeSet(0, 0, spec) +: deserializeSet(1, 1, spec) +: @@ -90,6 +91,25 @@ case class Aggs(postAggIR: IR, init: IR, seqPerElt: IR, aggs: Array[AggStateSign } def results: IR = ResultOp(0, aggs) + + def getPhysicalAggs(ctx: ExecuteContext, initBindings: Env[PType], seqBindings: Env[PType]): Array[AggStatePhysicalSignature] = { + val initsAB = InferPType.newBuilder[InitOp](aggs.length) + val seqsAB = InferPType.newBuilder[SeqOp](aggs.length) + val init2 = LoweringPipeline.compileLowerer.apply(ctx, init, false).asInstanceOf[IR] + val seq2 = LoweringPipeline.compileLowerer.apply(ctx, seqPerElt, false).asInstanceOf[IR] + InferPType(init2.noSharing, initBindings, null, inits = initsAB, null) + InferPType(seq2.noSharing, seqBindings, null, null, seqs = seqsAB) + + val pSigs = aggs.indices.map { i => InferPType.computePhysicalAgg(aggs(i), initsAB(i), seqsAB(i)) }.toArray + + if (init2 eq init) + InferPType.clearPTypes(init2) + if (seq2 eq seqPerElt) + InferPType.clearPTypes(seq2) + + // should return pSigs, but cannot until we use the inferred ptype to generate code + aggs.map(_.toCanonicalPhysical) + } } object Extract { diff --git a/hail/src/main/scala/is/hail/expr/ir/agg/StagedArrayBuilder.scala b/hail/src/main/scala/is/hail/expr/ir/agg/StagedArrayBuilder.scala index ad274ff249c..491ab934a99 100644 --- a/hail/src/main/scala/is/hail/expr/ir/agg/StagedArrayBuilder.scala +++ b/hail/src/main/scala/is/hail/expr/ir/agg/StagedArrayBuilder.scala @@ -14,6 +14,7 @@ object StagedArrayBuilder { class StagedArrayBuilder(eltType: PType, fb: EmitFunctionBuilder[_], region: Code[Region], var initialCapacity: Int = 8) { val eltArray = PArray(eltType.setRequired(false), required = true) // element type must be optional for serialization to work val stateType = PTuple(true, PInt32Required, 
PInt32Required, eltArray) + val size: ClassFieldRef[Int] = fb.newField[Int]("size") private val capacity = fb.newField[Int]("capacity") val data = fb.newField[Long]("data") diff --git a/hail/src/test/scala/is/hail/expr/ir/Aggregators2Suite.scala b/hail/src/test/scala/is/hail/expr/ir/Aggregators2Suite.scala index 6ece000a2cf..d00a6a122ef 100644 --- a/hail/src/test/scala/is/hail/expr/ir/Aggregators2Suite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/Aggregators2Suite.scala @@ -32,8 +32,9 @@ class Aggregators2Suite extends HailSuite { val argRef = Ref(genUID(), argT.virtualType) val spec = BufferSpec.defaultUncompressed + val psig = aggSig.toCanonicalPhysical val (_, combAndDuplicate) = CompileWithAggregators2[Unit](ctx, - Array.fill(nPartitions)(aggSig), + Array.fill(nPartitions)(psig), Begin( Array.tabulate(nPartitions)(i => DeserializeAggs(i, i, spec, Array(aggSig))) ++ Array.range(1, nPartitions).map(i => CombOp(0, i, aggSig)) :+ @@ -41,7 +42,7 @@ class Aggregators2Suite extends HailSuite { DeserializeAggs(1, 0, spec, Array(aggSig)))) val (rt: PTuple, resF) = CompileWithAggregators2[Long](ctx, - Array.fill(nPartitions)(aggSig), + Array.fill(nPartitions)(psig), ResultOp(0, Array(aggSig, aggSig))) assert(rt.types(0) == rt.types(1)) @@ -54,7 +55,7 @@ class Aggregators2Suite extends HailSuite { def withArgs(foo: IR) = { CompileWithAggregators2[Long, Unit](ctx, - Array(aggSig), + Array(psig), argRef.name, argRef.pType, args.map(_._1).foldLeft[IR](foo) { case (op, name) => Let(name, GetField(argRef, name), op) @@ -63,14 +64,14 @@ class Aggregators2Suite extends HailSuite { val serialize = SerializeAggs(0, 0, spec, Array(aggSig)) val (_, writeF) = CompileWithAggregators2[Unit](ctx, - Array(aggSig), + Array(psig), serialize) val initF = withArgs(initOp) expectedInit.foreach { v => val (rt: PBaseStruct, resOneF) = CompileWithAggregators2[Long](ctx, - Array(aggSig), ResultOp(0, Array(aggSig))) + Array(psig), ResultOp(0, Array(aggSig))) val init = initF(0, region) val res 
= resOneF(0, region) diff --git a/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala b/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala index dc72ecf8cb8..59d933ca21e 100644 --- a/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala @@ -339,7 +339,7 @@ class IRSuite extends HailSuite { // should not be able to infer physical type twice on one IR (i32na) node = ApplyUnaryPrimOp(Negate(), i32na) - intercept[AssertionError](InferPType(node, Env.empty)) + intercept[RuntimeException](InferPType(node, Env.empty)) node = ApplyUnaryPrimOp(Negate(), I64(5)) assertPType(node, PInt64(true))