From c60079506d200cf78c2f7ccec8bfd5f709ab19f9 Mon Sep 17 00:00:00 2001 From: Ankita Victor-Levi Date: Fri, 20 Feb 2026 20:10:35 +0530 Subject: [PATCH 1/4] Update rules --- .../velox/VeloxSparkPlanExecApi.scala | 17 ++++++++++ .../HashAggregateExecTransformer.scala | 33 +++++++++++++++++++ .../FlushableHashAggregateRule.scala | 12 ++++++- .../VeloxAggregateFunctionsSuite.scala | 31 +++++++++++++++++ .../gluten/backendsapi/SparkPlanExecApi.scala | 23 +++++++++++++ .../HashAggregateExecBaseTransformer.scala | 20 +++++++++++ .../columnar/EliminateLocalSort.scala | 4 +-- .../offload/OffloadSingleNodeRules.scala | 2 +- 8 files changed, 138 insertions(+), 4 deletions(-) diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala index 69419deb1a2a..45aaa3707350 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala @@ -349,6 +349,23 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { resultExpressions, child) + override def genSortAggregateExecTransformer( + requiredChildDistributionExpressions: Option[Seq[Expression]], + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + child: SparkPlan): HashAggregateExecBaseTransformer = + SortHashAggregateExecTransformer( + requiredChildDistributionExpressions, + groupingExpressions, + aggregateExpressions, + aggregateAttributes, + initialInputBufferOffset, + resultExpressions, + child) + /** Generate HashAggregateExecPullOutHelper */ override def genHashAggregateExecPullOutHelper( aggregateExpressions: Seq[AggregateExpression], diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala index 8097ea925ddb..a0c52c790907 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala @@ -623,6 +623,39 @@ case class RegularHashAggregateExecTransformer( } } +// Hash aggregation that was offloaded from a SortAggregateExec. Preserves sort-aggregate semantics +// so that upstream sort elimination rules can safely remove the preceding sort. +case class SortHashAggregateExecTransformer( + requiredChildDistributionExpressions: Option[Seq[Expression]], + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends HashAggregateExecTransformer( + requiredChildDistributionExpressions, + groupingExpressions, + aggregateExpressions, + aggregateAttributes, + initialInputBufferOffset, + resultExpressions, + child) + with SortAggregateExecTransformer { + + override protected def allowFlush: Boolean = false + + override def simpleString(maxFields: Int): String = + s"SortToHash${super.simpleString(maxFields)}" + + override def verboseString(maxFields: Int): String = + s"SortToHash${super.verboseString(maxFields)}" + + override protected def withNewChildInternal(newChild: SparkPlan): HashAggregateExecTransformer = { + copy(child = newChild) + } +} + // Hash aggregation that emits pre-aggregated data which allows duplications on grouping keys // among its output rows. case class FlushableHashAggregateExecTransformer( diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala index 0aa48d8d3770..7002a2bd93d0 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala @@ -83,7 +83,7 @@ case class FlushableHashAggregateRule(session: SparkSession) extends Rule[SparkP } private def replaceEligibleAggregates(plan: SparkPlan)( - func: RegularHashAggregateExecTransformer => SparkPlan): SparkPlan = { + func: HashAggregateExecTransformer => SparkPlan): SparkPlan = { def transformDown: SparkPlan => SparkPlan = { case agg: RegularHashAggregateExecTransformer if !agg.aggregateExpressions.forall(p => p.mode == Partial || p.mode == PartialMerge) => @@ -98,6 +98,16 @@ case class FlushableHashAggregateRule(session: SparkSession) extends Rule[SparkP agg case agg: RegularHashAggregateExecTransformer => func(agg) + case agg: SortHashAggregateExecTransformer + if !agg.aggregateExpressions.forall(p => p.mode == Partial || p.mode == PartialMerge) => + agg + case agg: SortHashAggregateExecTransformer if isAggInputAlreadyDistributedWithAggKeys(agg) => + agg + case agg: SortHashAggregateExecTransformer + if aggregatesNotSupportFlush(agg.aggregateExpressions) => + agg + case agg: SortHashAggregateExecTransformer => + func(agg) case p if !canPropagate(p) => p case other => other.withNewChildren(other.children.map(transformDown)) } diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala index ca259e0faeb7..4e0bc07784ac 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala @@ -1171,6 +1171,37 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu } } } + + test("test collect_list with ordering") { + withTempView("t1") { + Seq((2, "d"), (2, "e"), (2, "f"), (1, "b"), (1, "a"), (1, "c"), (3, "i"), (3, "h"), (3, "g")) + .toDF("id", "value") + .createOrReplaceTempView("t1") + runQueryAndCompare( + """ + | SELECT 1 - id, collect_list(value) AS values_list + | FROM ( + | select * from + | (SELECT id, value + | FROM t1 + | DISTRIBUTE BY rand()) + | DISTRIBUTE BY id sort by id,value + | ) t + | GROUP BY 1 + |""".stripMargin, + false + ) { + df => + { + assert( + getExecutedPlan(df).count( + plan => { + plan.isInstanceOf[SortHashAggregateExecTransformer] + }) == 2) + } + } + } + } } class VeloxAggregateFunctionsDefaultSuite extends VeloxAggregateFunctionsSuite { diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala index 002c8dad7b1b..27dbb6ce4c95 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala @@ -81,6 +81,29 @@ trait SparkPlanExecApi { resultExpressions: Seq[NamedExpression], child: SparkPlan): HashAggregateExecBaseTransformer + /** + * Generate a HashAggregateExecTransformer for a SortAggregateExec that is being offloaded to a + * native hash aggregate. The returned transformer preserves sort-aggregate semantics (e.g., + * requiredChildOrdering) so that upstream sort elimination rules can distinguish it from a + * regular hash aggregate. + */ + def genSortAggregateExecTransformer( + requiredChildDistributionExpressions: Option[Seq[Expression]], + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + child: SparkPlan): HashAggregateExecBaseTransformer = + genHashAggregateExecTransformer( + requiredChildDistributionExpressions, + groupingExpressions, + aggregateExpressions, + aggregateAttributes, + initialInputBufferOffset, + resultExpressions, + child) + /** Generate HashAggregateExecPullOutHelper */ def genHashAggregateExecPullOutHelper( aggregateExpressions: Seq[AggregateExpression], diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala index a4bcc6081e43..f4e174d9f509 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala @@ -196,8 +196,28 @@ object HashAggregateExecBaseTransformer { agg.child ) } + + def fromSortAggregate(agg: BaseAggregateExec): HashAggregateExecBaseTransformer = { + BackendsApiManager.getSparkPlanExecApiInstance + .genSortAggregateExecTransformer( + agg.requiredChildDistributionExpressions, + agg.groupingExpressions, + agg.aggregateExpressions, + agg.aggregateAttributes, + getInitialInputBufferOffset(agg), + agg.resultExpressions, + agg.child + ) + } } +/** + * Marker trait for hash aggregate transformers that were offloaded from a SortAggregateExec. This + * allows sort elimination rules to distinguish aggregates that were originally sort-based (and thus + * can safely eliminate their upstream sort) from regular hash aggregates (which must not). + */ +trait SortAggregateExecTransformer extends HashAggregateExecBaseTransformer {} + trait HashAggregateExecPullOutBaseHelper { // The direct outputs of Aggregation. def allAggregateResultAttributes(groupingExpressions: Seq[NamedExpression]): List[Attribute] = diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/EliminateLocalSort.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/EliminateLocalSort.scala index 8a2c731e5e03..17e7f29eec47 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/EliminateLocalSort.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/EliminateLocalSort.scala @@ -16,7 +16,7 @@ */ package org.apache.gluten.extension.columnar -import org.apache.gluten.execution.{HashAggregateExecBaseTransformer, ProjectExecTransformer, ShuffledHashJoinExecTransformerBase, SortExecTransformer, WindowGroupLimitExecTransformer} +import org.apache.gluten.execution.{ProjectExecTransformer, ShuffledHashJoinExecTransformerBase, SortAggregateExecTransformer, SortExecTransformer, WindowGroupLimitExecTransformer} import org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.catalyst.rules.Rule @@ -33,7 +33,7 @@ import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, UnaryEx */ object EliminateLocalSort extends Rule[SparkPlan] { private def canEliminateLocalSort(p: SparkPlan): Boolean = p match { - case _: HashAggregateExecBaseTransformer => true + case _: SortAggregateExecTransformer => true case _: ShuffledHashJoinExecTransformerBase => true case _: WindowGroupLimitExecTransformer => true case s: SortExec if s.global == false => true diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/offload/OffloadSingleNodeRules.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/offload/OffloadSingleNodeRules.scala index 684fbd36f1ac..3d844607b391 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/offload/OffloadSingleNodeRules.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/offload/OffloadSingleNodeRules.scala @@ -215,7 +215,7 @@ object OffloadOthers { case plan: HashAggregateExec => HashAggregateExecBaseTransformer.from(plan) case plan: SortAggregateExec => - HashAggregateExecBaseTransformer.from(plan) + HashAggregateExecBaseTransformer.fromSortAggregate(plan) case plan: ObjectHashAggregateExec => HashAggregateExecBaseTransformer.from(plan) case plan: UnionExec => From 851c25ee829fd2154ad6b5e0b06b8ef4f919fa15 Mon Sep 17 00:00:00 2001 From: Ankita Victor-Levi Date: Tue, 3 Mar 2026 15:19:30 +0530 Subject: [PATCH 2/4] Add comment --- .../gluten/extension/FlushableHashAggregateRule.scala | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala index 7002a2bd93d0..645cb056042a 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala @@ -82,6 +82,12 @@ case class FlushableHashAggregateRule(session: SparkSession) extends Rule[SparkP aggExprs.exists(isUnsupportedAggregation) } + /** + * Walks the plan downward, applying func to each RegularHashAggregateExecTransformer or + * SortHashAggregateExecTransformer that is eligible for flushable conversion. An aggregate + * is eligible when all expressions are Partial/PartialMerge, input is not already + * partitioned by the grouping keys, and no aggregate function disallows flushing. + */ private def replaceEligibleAggregates(plan: SparkPlan)( func: HashAggregateExecTransformer => SparkPlan): SparkPlan = { def transformDown: SparkPlan => SparkPlan = { From f120a594f185325faf230b1e1ef9490e095ba2f6 Mon Sep 17 00:00:00 2001 From: Ankita Victor-Levi Date: Tue, 3 Mar 2026 15:25:43 +0530 Subject: [PATCH 3/4] Add more comments --- .../gluten/extension/FlushableHashAggregateRule.scala | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala index 645cb056042a..1ff2c72238b1 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala @@ -93,26 +93,32 @@ case class FlushableHashAggregateRule(session: SparkSession) extends Rule[SparkP def transformDown: SparkPlan => SparkPlan = { case agg: RegularHashAggregateExecTransformer if !agg.aggregateExpressions.forall(p => p.mode == Partial || p.mode == PartialMerge) => - // Not a intermediate agg. Skip. + // Not an intermediate agg. Skip. agg case agg: RegularHashAggregateExecTransformer if isAggInputAlreadyDistributedWithAggKeys(agg) => - // Data already grouped by aggregate keys, Skip. + // Data already grouped by aggregate keys. Skip. agg case agg: RegularHashAggregateExecTransformer if aggregatesNotSupportFlush(agg.aggregateExpressions) => + // Aggregate uses a function that is unsafe to flush. Skip. agg case agg: RegularHashAggregateExecTransformer => + // All guards passed; replace with the flushable variant. func(agg) case agg: SortHashAggregateExecTransformer if !agg.aggregateExpressions.forall(p => p.mode == Partial || p.mode == PartialMerge) => + // Not an intermediate agg. Skip. agg case agg: SortHashAggregateExecTransformer if isAggInputAlreadyDistributedWithAggKeys(agg) => + // Data already grouped by aggregate keys. Skip. agg case agg: SortHashAggregateExecTransformer if aggregatesNotSupportFlush(agg.aggregateExpressions) => + // Aggregate uses a function that is unsafe to flush. Skip. agg case agg: SortHashAggregateExecTransformer => + // All guards passed; replace with the flushable variant. func(agg) case p if !canPropagate(p) => p case other => other.withNewChildren(other.children.map(transformDown)) From b3fdd34503b2129bda3f4629746eaf3a6af79acc Mon Sep 17 00:00:00 2001 From: Ankita Victor-Levi Date: Tue, 3 Mar 2026 19:30:49 +0530 Subject: [PATCH 4/4] Fix format --- .../gluten/extension/FlushableHashAggregateRule.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala index 1ff2c72238b1..6216dd8747f1 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala @@ -84,9 +84,9 @@ case class FlushableHashAggregateRule(session: SparkSession) extends Rule[SparkP /** * Walks the plan downward, applying func to each RegularHashAggregateExecTransformer or - * SortHashAggregateExecTransformer that is eligible for flushable conversion. An aggregate - * is eligible when all expressions are Partial/PartialMerge, input is not already - * partitioned by the grouping keys, and no aggregate function disallows flushing. + * SortHashAggregateExecTransformer that is eligible for flushable conversion. An aggregate is + * eligible when all expressions are Partial/PartialMerge, input is not already partitioned by the + * grouping keys, and no aggregate function disallows flushing. */ private def replaceEligibleAggregates(plan: SparkPlan)( func: HashAggregateExecTransformer => SparkPlan): SparkPlan = {