@@ -27,78 +27,126 @@ import org.apache.spark.sql.internal.SQLConf
2727import org .apache .spark .sql .types .DataType
2828
2929/**
30- * This rule tries to merge multiple non-correlated [[ScalarSubquery ]]s to compute multiple scalar
31- * values once.
30+ * This rule tries to merge multiple subplans that have one row result. This can be either the plan
31+ * tree of a [[ScalarSubquery ]] expression or the plan tree starting at a non-grouping [[Aggregate ]]
32+ * node.
3233 *
3334 * The process is the following:
34- * - While traversing through the plan each [[ ScalarSubquery ]] plan is tried to merge into already
35- * seen subquery plans using `PlanMerger`s.
35+ * - While traversing through the plan, we try to merge each one-row-returning subplan into the
36+ * already seen one-row-returning subplans using `PlanMerger`s.
3637 * During this first traversal each [[ScalarSubquery ]] expression is replaced with a temporary
37- * [[ScalarSubqueryReference ]] pointing to its possible merged version stored in `PlanMerger`s.
38- * `PlanMerger`s keep track of whether a plan is a result of merging 2 or more plans, or is an
39- * original unmerged plan. [[ScalarSubqueryReference ]]s contain all the required information to
40- * either restore the original [[ScalarSubquery ]] or create a reference to a merged CTE.
41- * - Once the first traversal is complete and all possible merging have been done a second traversal
42- * removes the [[ScalarSubqueryReference ]]s to either restore the original [[ScalarSubquery ]] or
43- * to replace the original to a modified one that references a CTE with a merged plan.
38+ * [[ScalarSubqueryReference ]] and each non-grouping [[Aggregate ]] node is replaced with a temporary
39+ * [[NonGroupingAggregateReference ]] pointing to its possible merged version in `PlanMerger`s.
40+ * `PlanMerger`s keep track of whether a plan is a result of merging 2 or more subplans, or is an
41+ * original unmerged plan.
42+ * [[ScalarSubqueryReference ]]s and [[NonGroupingAggregateReference ]]s contain all the required
43+ * information to either restore the original subplan or create a reference to a merged CTE.
44+ * - Once the first traversal is complete and all possible merges have been done, a second
45+ * traversal removes the references to either restore the original subplans or to replace the
46+ * originals with modified ones that reference a CTE with a merged plan.
4447 * A modified [[ScalarSubquery ]] is constructed like:
45- * `GetStructField(ScalarSubquery(CTERelationRef(...)), outputIndex)` where `outputIndex` is the
46- * index of the output attribute (of the CTE) that corresponds to the output of the original
47- * subquery.
48+ * `GetStructField(ScalarSubquery(CTERelationRef to the merged plan), merged output index)`
49+ * while a modified [[Aggregate ]] is constructed like:
50+ * ```
51+ * Project(
52+ * Seq(
53+ * GetStructField(
54+ * ScalarSubquery(CTERelationRef to the merged plan),
55+ * merged output index 1),
56+ * GetStructField(
57+ * ScalarSubquery(CTERelationRef to the merged plan),
58+ * merged output index 2),
59+ * ...),
60+ * OneRowRelation)
61+ * ```
62+ * where the `merged output index`es are the indices of the output attributes (of the CTE) that
63+ * correspond to the output of the original node.
4864 * - If there are merged subqueries in `PlanMerger`s then a `WithCTE` node is built from these
49- * queries. The `CTERelationDef` nodes contain the merged subquery in the following form:
50- * `Project(Seq(CreateNamedStruct(name1, attribute1, ...) AS mergedValue), mergedSubqueryPlan)`.
51- * The definitions are flagged that they host a subquery, that can return maximum one row.
65+ * queries. The `CTERelationDef` nodes contain the merged subplans in the following form:
66+ * `Project(Seq(CreateNamedStruct(name 1, attribute 1, ...) AS mergedValue), mergedSubplan)`.
5267 *
53- * Eg. the following query :
68+ * Here are a few examples :
5469 *
55- * SELECT
56- * (SELECT avg(a) FROM t),
57- * (SELECT sum(b) FROM t)
58- *
59- * is optimized from:
60- *
61- * == Optimized Logical Plan ==
62- * Project [scalar-subquery#242 [] AS scalarsubquery()#253,
63- * scalar-subquery#243 [] AS scalarsubquery()#254L]
64- * : :- Aggregate [avg(a#244) AS avg(a)#247]
65- * : : +- Project [a#244]
66- * : : +- Relation default.t[a#244,b#245] parquet
67- * : +- Aggregate [sum(a#251) AS sum(a)#250L]
68- * : +- Project [a#251]
69- * : +- Relation default.t[a#251,b#252] parquet
70+ * 1. a query with 2 subqueries:
71+ * ```
72+ * Project [scalar-subquery [] AS scalarsubquery(), scalar-subquery [] AS scalarsubquery()]
73+ * : :- Aggregate [min(a) AS min(a)]
74+ * : : +- Relation [a, b, c]
75+ * : +- Aggregate [sum(b) AS sum(b)]
76+ * : +- Relation [a, b, c]
7077 * +- OneRowRelation
78+ * ```
79+ * is optimized to:
80+ * ```
81+ * WithCTE
82+ * :- CTERelationDef 0
83+ * : +- Project [named_struct(min(a), min(a), sum(b), sum(b)) AS mergedValue]
84+ * : +- Aggregate [min(a) AS min(a), sum(b) AS sum(b)]
85+ * : +- Relation [a, b, c]
86+ * +- Project [scalar-subquery [].min(a) AS scalarsubquery(),
87+ * scalar-subquery [].sum(b) AS scalarsubquery()]
88+ * : :- CTERelationRef 0
89+ * : +- CTERelationRef 0
90+ * +- OneRowRelation
91+ * ```
7192 *
72- * to:
93+ * 2. a query with 2 non-grouping aggregates:
94+ * ```
95+ * Join Inner
96+ * :- Aggregate [min(a) AS min(a)]
97+ * : +- Relation [a, b, c]
98+ * +- Aggregate [sum(b) AS sum(b), avg(cast(c as double)) AS avg(c)]
99+ * +- Relation [a, b, c]
100+ * ```
101+ * is optimized to:
102+ * ```
103+ * WithCTE
104+ * :- CTERelationDef 0
105+ * : +- Project [named_struct(min(a), min(a), sum(b), sum(b), avg(c), avg(c)) AS mergedValue]
106+ * : +- Aggregate [min(a) AS min(a), sum(b) AS sum(b), avg(cast(c as double)) AS avg(c)]
107+ * : +- Relation [a, b, c]
108+ * +- Join Inner
109+ * :- Project [scalar-subquery [].min(a) AS min(a)]
110+ * : : +- CTERelationRef 0
111+ * : +- OneRowRelation
112+ * +- Project [scalar-subquery [].sum(b) AS sum(b), scalar-subquery [].avg(c) AS avg(c)]
113+ * : :- CTERelationRef 0
114+ * : +- CTERelationRef 0
115+ * +- OneRowRelation
116+ * ```
73117 *
74- * == Optimized Logical Plan ==
75- * Project [scalar-subquery#242 [].avg(a) AS scalarsubquery()#253,
76- * scalar-subquery#243 [].sum(a) AS scalarsubquery()#254L]
77- * : :- Project [named_struct(avg(a), avg(a)#247, sum(a), sum(a)#250L) AS mergedValue#260]
78- * : : +- Aggregate [avg(a#244) AS avg(a)#247, sum(a#244) AS sum(a)#250L]
79- * : : +- Project [a#244]
80- * : : +- Relation default.t[a#244,b#245] parquet
81- * : +- Project [named_struct(avg(a), avg(a)#247, sum(a), sum(a)#250L) AS mergedValue#260]
82- * : +- Aggregate [avg(a#244) AS avg(a)#247, sum(a#244) AS sum(a)#250L]
83- * : +- Project [a#244]
84- * : +- Relation default.t[a#244,b#245] parquet
85- * +- OneRowRelation
118+ * 3. a query with a subquery and a non-grouping aggregate:
119+ * ```
120+ * Join Inner
121+ * :- Project [scalar-subquery [] AS scalarsubquery()]
122+ * : : +- Aggregate [min(a) AS min(a)]
123+ * : : +- Relation [a, b, c]
124+ * : +- OneRowRelation
125+ * +- Aggregate [sum(b) AS sum(b), avg(cast(c as double)) AS avg(c)]
126+ * +- Relation [a, b, c]
127+ * ```
128+ * is optimized to:
129+ * ```
130+ * WithCTE
131+ * :- CTERelationDef 0
132+ * : +- Project [named_struct(min(a), min(a), sum(b), sum(b), avg(c), avg(c)) AS mergedValue]
133+ * : +- Aggregate [min(a) AS min(a), sum(b) AS sum(b), avg(cast(c as double)) AS avg(c)]
134+ * : +- Relation [a, b, c]
135+ * +- Join Inner
136+ * :- Project [scalar-subquery [].min(a) AS scalarsubquery()]
137+ * : : +- CTERelationRef 0
138+ * : +- OneRowRelation
139+ * +- Project [scalar-subquery [].sum(b) AS sum(b), scalar-subquery [].avg(c) AS avg(c)]
140+ * : :- CTERelationRef 0
141+ * : +- CTERelationRef 0
142+ * +- OneRowRelation
143+ * ```
86144 *
87- * == Physical Plan ==
88- * *(1) Project [Subquery scalar-subquery#242, [id=#125].avg(a) AS scalarsubquery()#253,
89- * ReusedSubquery
90- * Subquery scalar-subquery#242, [id=#125].sum(a) AS scalarsubquery()#254L]
91- * : :- Subquery scalar-subquery#242, [id=#125]
92- * : : +- *(2) Project [named_struct(avg(a), avg(a)#247, sum(a), sum(a)#250L) AS mergedValue#260]
93- * : : +- *(2) HashAggregate(keys=[], functions=[avg(a#244), sum(a#244)],
94- * output=[avg(a)#247, sum(a)#250L])
95- * : : +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#120]
96- * : : +- *(1) HashAggregate(keys=[], functions=[partial_avg(a#244), partial_sum(a#244)],
97- * output=[sum#262, count#263L, sum#264L])
98- * : : +- *(1) ColumnarToRow
99- * : : +- FileScan parquet default.t[a#244] ...
100- * : +- ReusedSubquery Subquery scalar-subquery#242, [id=#125]
101- * +- *(1) Scan OneRowRelation[]
145+ * Please note that in the above examples the aggregations are part of a "join group", which could
146+ * be rewritten as one aggregate without introducing a CTE or keeping the join. But there are
147+ * more complex cases where this CTE-based approach is the only viable option, for example when
148+ * the aggregates reside in different parts of the plan, maybe even in different subquery
149+ * expressions.
102150 */
103151object MergeSubplans extends Rule [LogicalPlan ] {
104152 def apply (plan : LogicalPlan ): LogicalPlan = {
@@ -123,7 +171,7 @@ object MergeSubplans extends Rule[LogicalPlan] {
123171
124172 // Traverse level by level and convert merged plans to `CTERelationDef`s and keep non-merged
125173 // ones. While traversing replace references in plans back to `CTERelationRef`s or to original
126- // plans. This is safe as a subplan at a level can reference only lower level ot other subplans.
174+ // plans. This is safe as a subplan at a level can reference only lower level subplans.
127175 val subplansByLevel = ArrayBuffer .empty[IndexedSeq [LogicalPlan ]]
128176 planMergers.foreach { planMerger =>
129177 val mergedPlans = planMerger.mergedPlans()
@@ -162,8 +210,9 @@ object MergeSubplans extends Rule[LogicalPlan] {
162210 }
163211
164212 // First traversal inserts `ScalarSubqueryReference`s and `NonGroupingAggregateReference`s to the
165- // plan and tries to merge subplans by each level. Levels are separated eiter by scalar subqueries
166- // or by non-grouping aggregate nodes. Nodes with the same level make sense to try merging.
213+ // plan and tries to merge subplans by each level. Levels are separated either by scalar
214+ // subqueries or by non-grouping aggregate nodes. Nodes with the same level make sense to try
215+ // merging.
167216 private def insertReferences (
168217 plan : LogicalPlan ,
169218 root : Boolean ,
@@ -224,9 +273,10 @@ object MergeSubplans extends Rule[LogicalPlan] {
224273 // parent
225274 (aggregateReference, level + 1 )
226275 case o =>
227- val (newChildren, levels) = o.children.map(insertReferences(_, false , planMergers)).unzip
276+ val (newChildren, levelsFromChildren) =
277+ o.children.map(insertReferences(_, false , planMergers)).unzip
228278 // Level is the maximum of the level from subqueries and the level from the children.
229- (o.withNewChildren(newChildren), (levelFromSubqueries +: levels ).max)
279+ (o.withNewChildren(newChildren), (levelFromSubqueries +: levelsFromChildren ).max)
230280 }
231281
232282 (planWithReferences, level)
0 commit comments