Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,23 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi with Logging {
resultExpressions,
child)

override def genSortAggregateExecTransformer(
    requiredChildDistributionExpressions: Option[Seq[Expression]],
    groupingExpressions: Seq[NamedExpression],
    aggregateExpressions: Seq[AggregateExpression],
    aggregateAttributes: Seq[Attribute],
    initialInputBufferOffset: Int,
    resultExpressions: Seq[NamedExpression],
    child: SparkPlan): HashAggregateExecBaseTransformer = {
  // Velox offloads a SortAggregateExec as a hash aggregate; use the dedicated
  // transformer so the plan retains its sort-aggregate origin for later rules.
  SortHashAggregateExecTransformer(
    requiredChildDistributionExpressions = requiredChildDistributionExpressions,
    groupingExpressions = groupingExpressions,
    aggregateExpressions = aggregateExpressions,
    aggregateAttributes = aggregateAttributes,
    initialInputBufferOffset = initialInputBufferOffset,
    resultExpressions = resultExpressions,
    child = child
  )
}

/** Generate HashAggregateExecPullOutHelper */
override def genHashAggregateExecPullOutHelper(
aggregateExpressions: Seq[AggregateExpression],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,39 @@ case class RegularHashAggregateExecTransformer(
}
}

// Hash aggregation converted from a SortAggregateExec. Carries the SortAggregateExecTransformer
// marker so upstream sort-elimination rules can safely drop the preceding sort, which they must
// not do for a regular hash aggregate.
case class SortHashAggregateExecTransformer(
    requiredChildDistributionExpressions: Option[Seq[Expression]],
    groupingExpressions: Seq[NamedExpression],
    aggregateExpressions: Seq[AggregateExpression],
    aggregateAttributes: Seq[Attribute],
    initialInputBufferOffset: Int,
    resultExpressions: Seq[NamedExpression],
    child: SparkPlan)
  extends HashAggregateExecTransformer(
    requiredChildDistributionExpressions,
    groupingExpressions,
    aggregateExpressions,
    aggregateAttributes,
    initialInputBufferOffset,
    resultExpressions,
    child)
  with SortAggregateExecTransformer {

  // NOTE(review): flushing is disabled here, presumably to keep one fully-aggregated row per
  // grouping key (the flushable variant may emit duplicates) — confirm against backend behavior.
  override protected def allowFlush: Boolean = false

  // Prefix plan strings so the sort-to-hash conversion is visible in explain output.
  override def simpleString(maxFields: Int): String =
    "SortToHash" + super.simpleString(maxFields)

  override def verboseString(maxFields: Int): String =
    "SortToHash" + super.verboseString(maxFields)

  override protected def withNewChildInternal(newChild: SparkPlan): HashAggregateExecTransformer =
    copy(child = newChild)
}

// Hash aggregation that emits pre-aggregated data which allows duplications on grouping keys
// among its output rows.
case class FlushableHashAggregateExecTransformer(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,21 +82,43 @@ case class FlushableHashAggregateRule(session: SparkSession) extends Rule[SparkP
aggExprs.exists(isUnsupportedAggregation)
}

/**
* Walks the plan downward, applying func to each RegularHashAggregateExecTransformer or
* SortHashAggregateExecTransformer that is eligible for flushable conversion. An aggregate is
* eligible when all expressions are Partial/PartialMerge, input is not already partitioned by the
* grouping keys, and no aggregate function disallows flushing.
*/
private def replaceEligibleAggregates(plan: SparkPlan)(
func: RegularHashAggregateExecTransformer => SparkPlan): SparkPlan = {
func: HashAggregateExecTransformer => SparkPlan): SparkPlan = {
def transformDown: SparkPlan => SparkPlan = {
case agg: RegularHashAggregateExecTransformer
if !agg.aggregateExpressions.forall(p => p.mode == Partial || p.mode == PartialMerge) =>
// Not a intermediate agg. Skip.
// Not an intermediate agg. Skip.
agg
case agg: RegularHashAggregateExecTransformer
if isAggInputAlreadyDistributedWithAggKeys(agg) =>
// Data already grouped by aggregate keys, Skip.
// Data already grouped by aggregate keys. Skip.
agg
case agg: RegularHashAggregateExecTransformer
if aggregatesNotSupportFlush(agg.aggregateExpressions) =>
// Aggregate uses a function that is unsafe to flush. Skip.
agg
case agg: RegularHashAggregateExecTransformer =>
// All guards passed; replace with the flushable variant.
func(agg)
case agg: SortHashAggregateExecTransformer
if !agg.aggregateExpressions.forall(p => p.mode == Partial || p.mode == PartialMerge) =>
// Not an intermediate agg. Skip.
agg
case agg: SortHashAggregateExecTransformer if isAggInputAlreadyDistributedWithAggKeys(agg) =>
// Data already grouped by aggregate keys. Skip.
agg
case agg: SortHashAggregateExecTransformer
if aggregatesNotSupportFlush(agg.aggregateExpressions) =>
// Aggregate uses a function that is unsafe to flush. Skip.
agg
case agg: SortHashAggregateExecTransformer =>
// All guards passed; replace with the flushable variant.
func(agg)
case p if !canPropagate(p) => p
case other => other.withNewChildren(other.children.map(transformDown))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1180,6 +1180,37 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu
}
}
}

test("test collect_list with ordering") {
  withTempView("t1") {
    // Deliberately shuffled input: correctness of collect_list here depends on the
    // sort-by clause surviving planning.
    Seq((2, "d"), (2, "e"), (2, "f"), (1, "b"), (1, "a"), (1, "c"), (3, "i"), (3, "h"), (3, "g"))
      .toDF("id", "value")
      .createOrReplaceTempView("t1")
    runQueryAndCompare(
      """
        | SELECT 1 - id, collect_list(value) AS values_list
        | FROM (
        | select * from
        | (SELECT id, value
        | FROM t1
        | DISTRIBUTE BY rand())
        | DISTRIBUTE BY id sort by id,value
        | ) t
        | GROUP BY 1
        |""".stripMargin,
      false
    ) {
      // Both the partial and the final aggregation should be planned as
      // sort-derived hash aggregate transformers.
      df =>
        assert(
          getExecutedPlan(df).count(_.isInstanceOf[SortHashAggregateExecTransformer]) == 2)
    }
  }
}
}

class VeloxAggregateFunctionsDefaultSuite extends VeloxAggregateFunctionsSuite {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,29 @@ trait SparkPlanExecApi {
resultExpressions: Seq[NamedExpression],
child: SparkPlan): HashAggregateExecBaseTransformer

/**
 * Generates the transformer used when a SortAggregateExec is offloaded to a native hash
 * aggregate. The default implementation reuses the regular hash-aggregate transformer; backends
 * that must preserve sort-aggregate semantics (so that sort-elimination rules can tell the two
 * apart) override this method.
 *
 * @param requiredChildDistributionExpressions distribution required of the child, if any
 * @param groupingExpressions grouping keys of the aggregate
 * @param aggregateExpressions aggregate functions with their modes
 * @param aggregateAttributes output attributes of the aggregate functions
 * @param initialInputBufferOffset offset of aggregation buffers in the input rows
 * @param resultExpressions final output projection of the aggregate
 * @param child the input plan
 * @return a hash-aggregate transformer for the offloaded sort aggregate
 */
def genSortAggregateExecTransformer(
    requiredChildDistributionExpressions: Option[Seq[Expression]],
    groupingExpressions: Seq[NamedExpression],
    aggregateExpressions: Seq[AggregateExpression],
    aggregateAttributes: Seq[Attribute],
    initialInputBufferOffset: Int,
    resultExpressions: Seq[NamedExpression],
    child: SparkPlan): HashAggregateExecBaseTransformer = {
  genHashAggregateExecTransformer(
    requiredChildDistributionExpressions,
    groupingExpressions,
    aggregateExpressions,
    aggregateAttributes,
    initialInputBufferOffset,
    resultExpressions,
    child)
}

/** Generate HashAggregateExecPullOutHelper */
def genHashAggregateExecPullOutHelper(
aggregateExpressions: Seq[AggregateExpression],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,28 @@ object HashAggregateExecBaseTransformer {
agg.child
)
}

/** Builds the backend-specific transformer for an offloaded SortAggregateExec. */
def fromSortAggregate(agg: BaseAggregateExec): HashAggregateExecBaseTransformer =
  BackendsApiManager.getSparkPlanExecApiInstance.genSortAggregateExecTransformer(
    agg.requiredChildDistributionExpressions,
    agg.groupingExpressions,
    agg.aggregateExpressions,
    agg.aggregateAttributes,
    getInitialInputBufferOffset(agg),
    agg.resultExpressions,
    agg.child
  )
}

/**
 * Marker trait for hash-aggregate transformers that originated from a SortAggregateExec. This
 * allows sort-elimination rules to distinguish aggregates that were originally sort-based (and
 * thus can safely eliminate their upstream sort) from regular hash aggregates (which must not).
 */
trait SortAggregateExecTransformer extends HashAggregateExecBaseTransformer

trait HashAggregateExecPullOutBaseHelper {
// The direct outputs of Aggregation.
def allAggregateResultAttributes(groupingExpressions: Seq[NamedExpression]): List[Attribute] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
package org.apache.gluten.extension.columnar

import org.apache.gluten.execution.{HashAggregateExecBaseTransformer, ProjectExecTransformer, ShuffledHashJoinExecTransformerBase, SortExecTransformer, WindowGroupLimitExecTransformer}
import org.apache.gluten.execution.{ProjectExecTransformer, ShuffledHashJoinExecTransformerBase, SortAggregateExecTransformer, SortExecTransformer, WindowGroupLimitExecTransformer}

import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.catalyst.rules.Rule
Expand All @@ -33,7 +33,7 @@ import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, UnaryEx
*/
object EliminateLocalSort extends Rule[SparkPlan] {
private def canEliminateLocalSort(p: SparkPlan): Boolean = p match {
case _: HashAggregateExecBaseTransformer => true
case _: SortAggregateExecTransformer => true
case _: ShuffledHashJoinExecTransformerBase => true
case _: WindowGroupLimitExecTransformer => true
case s: SortExec if s.global == false => true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ object OffloadOthers {
case plan: HashAggregateExec =>
HashAggregateExecBaseTransformer.from(plan)
case plan: SortAggregateExec =>
HashAggregateExecBaseTransformer.from(plan)
HashAggregateExecBaseTransformer.fromSortAggregate(plan)
case plan: ObjectHashAggregateExec =>
HashAggregateExecBaseTransformer.from(plan)
case plan: UnionExec =>
Expand Down
Loading