From 9e6221dfcfe55ab68af9c2d96f2e57014427ff92 Mon Sep 17 00:00:00 2001
From: Shujing Yang 
Date: Thu, 4 Sep 2025 13:49:06 -0700
Subject: [PATCH 1/7] init

---
 .../plans/physical/partitioning.scala         |  49 +++++++-
 .../spark/sql/catalyst/ShuffleSpecSuite.scala |  70 +++++++++++
 .../exchange/EnsureRequirements.scala         |  23 +++-
 .../exchange/EnsureRequirementsSuite.scala    | 112 ++++++++++++++++++
 4 files changed, 252 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index f855483ea3c38..fc6e0f5b522e2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -638,7 +638,10 @@ case class ShufflePartitionIdPassThrough(
     expr: DirectShufflePartitionID,
     numPartitions: Int) extends Expression with Partitioning with Unevaluable {
 
-  // TODO(SPARK-53401): Support Shuffle Spec in Direct Partition ID Pass Through
+  override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = {
+    ShufflePartitionIdPassThroughSpec(this, distribution)
+  }
+
   def partitionIdExpression: Expression = Pmod(expr.child, Literal(numPartitions))
 
   def expressions: Seq[Expression] = expr :: Nil
@@ -966,6 +969,50 @@ object KeyGroupedShuffleSpec {
   }
 }
 
+case class ShufflePartitionIdPassThroughSpec(
+    partitioning: ShufflePartitionIdPassThrough,
+    distribution: ClusteredDistribution) extends ShuffleSpec {
+
+  /**
+   * The set of positions at which the partitioning expression appears among the
+   * distribution's clustering keys. Similar to HashShuffleSpec, this maps the (single)
+   * partitioning expression to its positions in the clustering keys.
+   */
+  lazy val keyPositions: mutable.BitSet = {
+    val distKeyToPos = mutable.Map.empty[Expression, mutable.BitSet]
+    distribution.clustering.zipWithIndex.foreach { case (distKey, distKeyPos) =>
+      distKeyToPos.getOrElseUpdate(distKey.canonicalized, mutable.BitSet.empty).add(distKeyPos)
+    }
+    distKeyToPos.getOrElse(partitioning.expr.child.canonicalized, mutable.BitSet.empty)
+  }
+
+  override def isCompatibleWith(other: ShuffleSpec): Boolean = other match {
+    case SinglePartitionShuffleSpec =>
+      partitioning.numPartitions == 1
+    case otherPassThroughSpec @ ShufflePartitionIdPassThroughSpec(
+        otherPartitioning, otherDistribution) =>
+      // As ShufflePartitionIdPassThrough only allows a single expression
+      // as the partitioning expression, we check compatibility as follows:
+      // 1. Same number of clustering expressions
+      // 2. Same number of partitions
+      // 3. each pair of partitioning expression from both sides has overlapping positions in their
+      //    corresponding distributions.
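// For example (an illustrative aside, not part of the patch): a left side
// partitioned by DirectShufflePartitionID(a) over ClusteredDistribution(a, b)
// has keyPositions = {0}; a right side partitioned by DirectShufflePartitionID(c)
// over ClusteredDistribution(c, d) also has keyPositions = {0}. With equal
// numPartitions, the two sides are compatible because the position sets intersect.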
+ distribution.clustering.length == otherDistribution.clustering.length && + partitioning.numPartitions == otherPartitioning.numPartitions && { + val otherKeyPositions = otherPassThroughSpec.keyPositions + keyPositions.intersect(otherKeyPositions).nonEmpty + } + case ShuffleSpecCollection(specs) => + specs.exists(isCompatibleWith) + case _ => + false + } + + override def canCreatePartitioning: Boolean = false + + override def numPartitions: Int = partitioning.numPartitions +} + case class ShuffleSpecCollection(specs: Seq[ShuffleSpec]) extends ShuffleSpec { override def isCompatibleWith(other: ShuffleSpec): Boolean = { specs.exists(_.isCompatibleWith(other)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala index fc5d39fd9c2bb..01ccc390f7dd1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.{SparkFunSuite, SparkUnsupportedOperationException} import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.DirectShufflePartitionID import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.internal.SQLConf @@ -448,6 +449,7 @@ class ShuffleSpecSuite extends SparkFunSuite with SQLHelper { HashPartitioning(Seq($"a", $"b"), 10) ) } + // TODO(shujing): canCreatePartitioning should always return false test("createPartitioning: other specs") { val distribution = ClusteredDistribution(Seq($"a", $"b")) @@ -479,4 +481,72 @@ class ShuffleSpecSuite extends SparkFunSuite with SQLHelper { "methodName" -> "createPartitioning$", "className" -> "org.apache.spark.sql.catalyst.plans.physical.ShuffleSpec")) } + + test("compatibility: ShufflePartitionIdPassThroughSpec on both sides") { + val dist = ClusteredDistribution(Seq($"a", $"b")) + val p1 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"a"), 10) + val p2 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"c"), 10) + + // Identical specs should be compatible + checkCompatible( + p1.createShuffleSpec(dist), + p2.createShuffleSpec(ClusteredDistribution(Seq($"c", $"d"))), + expected = true + ) + + // Different number of partitions should be incompatible + val p3 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"c"), 5) + checkCompatible( + p1.createShuffleSpec(dist), + p3.createShuffleSpec(ClusteredDistribution(Seq($"c", $"d"))), + expected = false + ) + + // Mismatched key positions should be incompatible + val dist1 = ClusteredDistribution(Seq($"a", $"b")) + val p4 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"b"), 10) // Key at pos 1 + val dist2 = ClusteredDistribution(Seq($"c", $"d")) + val p5 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"c"), 10) // Key at pos 0 + checkCompatible( + p4.createShuffleSpec(dist1), + p5.createShuffleSpec(dist2), + expected = false + ) + + // Mismatched clustering keys + val dist3 = ClusteredDistribution(Seq($"e", $"b")) + checkCompatible( + p1.createShuffleSpec(dist3), + p2.createShuffleSpec(dist2), + expected = false + ) + } + + test("compatibility: ShufflePartitionIdPassThroughSpec vs other specs") { + val dist = ClusteredDistribution(Seq($"a", $"b")) + val p = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"a"), 10) + 
+ // Compatibility with SinglePartitionShuffleSpec when numPartitions is 1 + val p1 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"a"), 1) + checkCompatible( + p1.createShuffleSpec(dist), + SinglePartitionShuffleSpec, + expected = true + ) + + // Incompatible with SinglePartitionShuffleSpec when numPartitions > 1 + checkCompatible( + p.createShuffleSpec(dist), + SinglePartitionShuffleSpec, + expected = false + ) + + // Incompatible with HashShuffleSpec + val p2 = HashPartitioning(Seq($"c"), 10) + checkCompatible( + p.createShuffleSpec(dist), + p2.createShuffleSpec(ClusteredDistribution(Seq($"c", $"d"))), + expected = false + ) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index a0fc4b65fdbf3..2e605e7709991 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -179,8 +179,13 @@ case class EnsureRequirements( newChildren.isDefined } + val isShufflePassThroughCompatible = !isKeyGroupCompatible && + parent.isDefined && children.length == 2 && childrenIndexes.length == 2 && + checkShufflePartitionIdPassThroughCompatible( + children.head, children(1), requiredChildDistributions) + children = children.zip(requiredChildDistributions).zipWithIndex.map { - case ((child, _), idx) if isKeyGroupCompatible || !childrenIndexes.contains(idx) => + case ((child, _), idx) if isKeyGroupCompatible || isShufflePassThroughCompatible || !childrenIndexes.contains(idx) => child case ((child, dist), idx) => if (bestSpecOpt.isDefined && bestSpecOpt.get.isCompatibleWith(specs(idx))) { @@ -600,6 +605,22 @@ case class EnsureRequirements( if (isCompatible) Some(Seq(newLeft, newRight)) else None } + private def checkShufflePartitionIdPassThroughCompatible( + left: SparkPlan, + right: SparkPlan, + requiredChildDistribution: Seq[Distribution]): Boolean = { + (left.outputPartitioning, right.outputPartitioning) match { + case (p1: ShufflePartitionIdPassThrough, p2: ShufflePartitionIdPassThrough) => + val leftSpec = p1.createShuffleSpec( + requiredChildDistribution.head.asInstanceOf[ClusteredDistribution]) + val rightSpec = p2.createShuffleSpec( + requiredChildDistribution(1).asInstanceOf[ClusteredDistribution]) + leftSpec.isCompatibleWith(rightSpec) + case _ => + false + } + } + // Similar to `OptimizeSkewedJoin.canSplitRightSide` private def canReplicateLeftSide(joinType: JoinType): Boolean = { joinType == Inner || joinType == Cross || joinType == RightOuter diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index 3b0bb088a1076..0b581a83674f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.exchange import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.DirectShufflePartitionID import org.apache.spark.sql.catalyst.expressions.aggregate.Sum import org.apache.spark.sql.catalyst.optimizer.BuildRight import 
org.apache.spark.sql.catalyst.plans.Inner
@@ -1196,6 +1197,117 @@ class EnsureRequirementsSuite extends SharedSparkSession {
     TransformExpression(BucketFunction, expr, Some(numBuckets))
   }
 
+  test("ShufflePartitionIdPassThrough - always shuffles due to canCreatePartitioning=false") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // Even with identical partitioning and join keys, shuffles are added
+      // because ShufflePartitionIdPassThroughSpec.canCreatePartitioning = false
+      val plan1 = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val plan2 = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val smjExec = SortMergeJoinExec(exprA :: Nil, exprA :: Nil, Inner, None, plan1, plan2)
+
+      EnsureRequirements.apply(smjExec) match {
+        case SortMergeJoinExec(
+          leftKeys,
+          rightKeys,
+          _,
+          _,
+          SortExec(_, _, DummySparkPlan(_, _, _: ShufflePartitionIdPassThrough, _, _), _),
+          SortExec(_, _, DummySparkPlan(_, _, _: ShufflePartitionIdPassThrough, _, _), _),
+          _
+        ) =>
+          assert(leftKeys === Seq(exprA))
+          assert(rightKeys === Seq(exprA))
+        case other => fail(s"We don't expect a shuffle on either side, but got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough incompatibility - different partitions") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // Different number of partitions - should add shuffles
+      val plan1 = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val plan2 = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 8))
+      val smjExec = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, plan1, plan2)
+
+      EnsureRequirements.apply(smjExec) match {
+        case SortMergeJoinExec(_, _, _, _,
+          SortExec(_, _, ShuffleExchangeExec(p1: HashPartitioning, _, _, _), _),
+          SortExec(_, _, ShuffleExchangeExec(p2: HashPartitioning, _, _, _), _), _) =>
+          // Both sides should be shuffled to default partitions
+          assert(p1.numPartitions == 10)
+          assert(p2.numPartitions == 10)
+          assert(p1.expressions == Seq(exprA))
+          assert(p2.expressions == Seq(exprB))
+        case other => fail(s"Expected shuffles on both sides, but got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough incompatibility - key position mismatch") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // Key position mismatch - should add shuffles
+      val plan1 = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val plan2 = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprC), 5))
+      // Join on different keys than partitioning keys
+      val smjExec = SortMergeJoinExec(exprB :: Nil, exprD :: Nil, Inner, None, plan1, plan2)
+
+      EnsureRequirements.apply(smjExec) match {
+        case SortMergeJoinExec(_, _, _, _,
+          SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _),
+          SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _), _) =>
+          // Both sides shuffled due to key mismatch
+        case other => fail(s"Expected shuffles on both sides, but got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough vs HashPartitioning - always shuffles") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // ShufflePartitionIdPassThrough vs HashPartitioning - always adds shuffles
+      val plan1 = DummySparkPlan(
+        outputPartitioning = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5)) + val plan2 = DummySparkPlan( + outputPartitioning = HashPartitioning(exprB :: Nil, 5)) + val smjExec = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, plan1, plan2) + + EnsureRequirements.apply(smjExec) match { + case SortMergeJoinExec(_, _, _, _, + SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _), + SortExec(_, _, _: DummySparkPlan, _), _) => + // Left side shuffled, right side kept as-is + case SortMergeJoinExec(_, _, _, _, + SortExec(_, _, _: DummySparkPlan, _), + SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _), _) => + // Right side shuffled, left side kept as-is + case other => fail(s"Expected shuffle on at least one side, but got: $other") + } + } + } + + test("ShufflePartitionIdPassThrough vs SinglePartition - shuffles added") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "5") { + // Even when compatible (numPartitions=1), shuffles added due to canCreatePartitioning=false + val plan1 = DummySparkPlan( + outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 1)) + val plan2 = DummySparkPlan(outputPartitioning = SinglePartition) + val smjExec = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, plan1, plan2) + + EnsureRequirements.apply(smjExec) match { + case SortMergeJoinExec(_, _, _, _, + SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _), + SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _), _) => + // Both sides shuffled due to canCreatePartitioning = false + case other => fail(s"Expected shuffles on both sides, but got: $other") + } + } + } + def years(expr: Expression): TransformExpression = { TransformExpression(YearsFunction, Seq(expr)) } From 62aa22ee38c6a90f454a4c5b67340e2658fa5fd9 Mon Sep 17 00:00:00 2001 From: Shujing Yang Date: Wed, 24 Sep 2025 17:38:56 -0700 Subject: [PATCH 2/7] init --- .../plans/physical/partitioning.scala | 3 +- .../exchange/EnsureRequirements.scala | 3 +- .../exchange/EnsureRequirementsSuite.scala | 142 +++++++++++++++++- 3 files changed, 142 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index fc6e0f5b522e2..7712ae7f7a23b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -638,6 +638,7 @@ case class ShufflePartitionIdPassThrough( expr: DirectShufflePartitionID, numPartitions: Int) extends Expression with Partitioning with Unevaluable { + // We don't support creating partitioning for ShufflePartitionIdPassThrough. 
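// A minimal usage sketch (illustrative, not part of the patch); `a` and `b` here
// are placeholder attribute expressions:
//   val part = ShufflePartitionIdPassThrough(DirectShufflePartitionID(a), 8)
//   val spec = part.createShuffleSpec(ClusteredDistribution(Seq(a, b)))
//   spec.numPartitions   // 8
//   spec.keyPositions    // BitSet(0): `a` sits at position 0 of the clustering keys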
override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = { ShufflePartitionIdPassThroughSpec(this, distribution) } @@ -987,8 +988,6 @@ case class ShufflePartitionIdPassThroughSpec( } override def isCompatibleWith(other: ShuffleSpec): Boolean = other match { - case SinglePartitionShuffleSpec => - partitioning.numPartitions == 1 case otherPassThroughSpec @ ShufflePartitionIdPassThroughSpec( otherPartitioning, otherDistribution) => // As ShufflePartitionIdPassThrough only allows a single expression diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 2e605e7709991..b80c24272edd9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -185,7 +185,8 @@ case class EnsureRequirements( children.head, children(1), requiredChildDistributions) children = children.zip(requiredChildDistributions).zipWithIndex.map { - case ((child, _), idx) if isKeyGroupCompatible || isShufflePassThroughCompatible || !childrenIndexes.contains(idx) => + case ((child, _), idx) if isKeyGroupCompatible || isShufflePassThroughCompatible || + !childrenIndexes.contains(idx) => child case ((child, dist), idx) => if (bestSpecOpt.isDefined && bestSpecOpt.get.isCompatibleWith(specs(idx))) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index 0b581a83674f6..4887c0b704b02 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -1197,10 +1197,8 @@ class EnsureRequirementsSuite extends SharedSparkSession { TransformExpression(BucketFunction, expr, Some(numBuckets)) } - test("ShufflePartitionIdPassThrough - always shuffles due to canCreatePartitioning=false") { + test("ShufflePartitionIdPassThrough - avoid necessary shuffle when they are compatible") { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") { - // Even with identical partitioning and join keys, shuffles are added - // because ShufflePartitionIdPassThroughSpec.canCreatePartitioning = false val plan1 = DummySparkPlan( outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5)) val plan2 = DummySparkPlan( @@ -1308,6 +1306,144 @@ class EnsureRequirementsSuite extends SharedSparkSession { } } + test("ShufflePartitionIdPassThrough - incompatible due to different expressions " + + "with same base column") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") { + // Even though both use exprA as base and have same numPartitions, + // different Pmod operations make them incompatible + val plan1 = DummySparkPlan( + outputPartitioning = ShufflePartitionIdPassThrough( + DirectShufflePartitionID(Pmod(exprA, Literal(10))), 5)) + val plan2 = DummySparkPlan( + outputPartitioning = ShufflePartitionIdPassThrough( + DirectShufflePartitionID(Pmod(exprA, Literal(5))), 5)) + val smjExec = SortMergeJoinExec(exprA :: Nil, exprA :: Nil, Inner, None, plan1, plan2) + + EnsureRequirements.apply(smjExec) match { + case SortMergeJoinExec(_, _, _, _, + SortExec(_, _, ShuffleExchangeExec(p1: HashPartitioning, _, _, _), _), + SortExec(_, _, 
ShuffleExchangeExec(p2: HashPartitioning, _, _, _), _), _) =>
+          // Both sides should be shuffled due to expression mismatch
+          assert(p1.numPartitions == 10)
+          assert(p2.numPartitions == 10)
+          assert(p1.expressions == Seq(exprA))
+          assert(p2.expressions == Seq(exprA))
+        case other => fail(s"Expected shuffles on both sides, but got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough - compatible with multiple clustering keys") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // Both partitioned by exprA, joined on (exprA, exprB)
+      // Should be compatible because exprA positions overlap
+      val plan1 = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val plan2 = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val smjExec = SortMergeJoinExec(exprA :: exprB :: Nil, exprA :: exprB :: Nil, Inner, None,
+        plan1, plan2)
+
+      EnsureRequirements.apply(smjExec) match {
+        case SortMergeJoinExec(
+          leftKeys,
+          rightKeys,
+          _,
+          _,
+          SortExec(_, _, DummySparkPlan(_, _, _: ShufflePartitionIdPassThrough, _, _), _),
+          SortExec(_, _, DummySparkPlan(_, _, _: ShufflePartitionIdPassThrough, _, _), _),
+          _
+        ) =>
+          assert(leftKeys === Seq(exprA, exprB))
+          assert(rightKeys === Seq(exprA, exprB))
+        case other => fail(s"We don't expect a shuffle on either side with multiple " +
+          s"clustering keys, but got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough - incompatible when partition key not in join keys") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // Partitioned by exprA and exprB respectively, but joining on completely different keys
+      // Should require shuffles because partition keys don't match join keys
+      val plan1 = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val plan2 = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 5))
+      val smjExec = SortMergeJoinExec(exprC :: Nil, exprD :: Nil, Inner, None, plan1, plan2)
+
+      EnsureRequirements.apply(smjExec) match {
+        case SortMergeJoinExec(_, _, _, _,
+          SortExec(_, _, ShuffleExchangeExec(p1: HashPartitioning, _, _, _), _),
+          SortExec(_, _, ShuffleExchangeExec(p2: HashPartitioning, _, _, _), _), _) =>
+          // Both sides should be shuffled because partition keys not in join keys
+          assert(p1.numPartitions == 10)
+          assert(p2.numPartitions == 10)
+          assert(p1.expressions == Seq(exprC))
+          assert(p2.expressions == Seq(exprD))
+        case other => fail(s"Expected shuffles on both sides due to key mismatch, but got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough - cross position matching behavior") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // Left partitioned by exprA, right partitioned by exprB
+      // Both sides join on (exprA, exprB)
+      // Test if cross-position matching works: left partition key exprA matches right join key
+      // exprA (pos 0)
+      // and right partition key exprB matches left join key exprB (pos 1)
+      val plan1 = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val plan2 = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 5))
+      val smjExec = SortMergeJoinExec(exprA :: exprB :: Nil, exprA :: exprB :: Nil, Inner, None,
+        plan1, plan2)
+
+      EnsureRequirements.apply(smjExec) match {
+        case SortMergeJoinExec(_, _, _, _,
+          SortExec(_, _, ShuffleExchangeExec(p1: 
HashPartitioning, _, _, _), _), + SortExec(_, _, ShuffleExchangeExec(p2: HashPartitioning, _, _, _), _), _) => + assert(p1.numPartitions == 10) + assert(p2.numPartitions == 10) + assert(p1.expressions == Seq(exprA, exprB)) + assert(p2.expressions == Seq(exprA, exprB)) + case other => fail(s"Expected either no shuffles (if compatible) or shuffles on " + + s"both sides, but got: $other") + } + } + } + + test("ShufflePartitionIdPassThrough - compatible when partition key matches at any position") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") { + // Both sides partitioned by exprB and join on (exprA, exprB) + // Should be compatible because partition key exprB matches at position 1 in join keys + val plan1 = DummySparkPlan( + outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 5)) + val plan2 = DummySparkPlan( + outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 5)) + val smjExec = SortMergeJoinExec(exprA :: exprB :: Nil, exprA :: exprB :: Nil, Inner, None, + plan1, plan2) + + EnsureRequirements.apply(smjExec) match { + case SortMergeJoinExec( + leftKeys, + rightKeys, + _, + _, + SortExec(_, _, DummySparkPlan(_, _, _: ShufflePartitionIdPassThrough, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, _: ShufflePartitionIdPassThrough, _, _), _), + _ + ) => + // No shuffles because exprB (partition key) appears at position 1 in join keys + assert(leftKeys === Seq(exprA, exprB)) + assert(rightKeys === Seq(exprA, exprB)) + case other => fail(s"Expected no shuffles due to position overlap at position 1, " + + s"but got: $other") + } + } + } + def years(expr: Expression): TransformExpression = { TransformExpression(YearsFunction, Seq(expr)) } From d2fab0e5e10f843f76794796bf920b06c7a8d0e6 Mon Sep 17 00:00:00 2001 From: Shujing Yang Date: Wed, 24 Sep 2025 23:57:34 -0700 Subject: [PATCH 3/7] SinglePartitionShuffleSpec --- .../apache/spark/sql/catalyst/plans/physical/partitioning.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 7712ae7f7a23b..6b110cab262e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -988,6 +988,8 @@ case class ShufflePartitionIdPassThroughSpec( } override def isCompatibleWith(other: ShuffleSpec): Boolean = other match { + case SinglePartitionShuffleSpec => + partitioning.numPartitions == 1 case otherPassThroughSpec @ ShufflePartitionIdPassThroughSpec( otherPartitioning, otherDistribution) => // As ShufflePartitionIdPassThrough only allows a single expression From c53123d63d4df14adea4414f06ae21721b926cbb Mon Sep 17 00:00:00 2001 From: Shujing Yang Date: Thu, 25 Sep 2025 00:01:17 -0700 Subject: [PATCH 4/7] lint --- .../spark/sql/catalyst/ShuffleSpecSuite.scala | 1 - .../exchange/EnsureRequirements.scala | 23 ++++++++++--------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala index 01ccc390f7dd1..dac9b1a906b32 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala 
@@ -449,7 +449,6 @@ class ShuffleSpecSuite extends SparkFunSuite with SQLHelper { HashPartitioning(Seq($"a", $"b"), 10) ) } - // TODO(shujing): canCreatePartitioning should always return false test("createPartitioning: other specs") { val distribution = ClusteredDistribution(Seq($"a", $"b")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index b80c24272edd9..06590c20ac28d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -609,17 +609,18 @@ case class EnsureRequirements( private def checkShufflePartitionIdPassThroughCompatible( left: SparkPlan, right: SparkPlan, - requiredChildDistribution: Seq[Distribution]): Boolean = { - (left.outputPartitioning, right.outputPartitioning) match { - case (p1: ShufflePartitionIdPassThrough, p2: ShufflePartitionIdPassThrough) => - val leftSpec = p1.createShuffleSpec( - requiredChildDistribution.head.asInstanceOf[ClusteredDistribution]) - val rightSpec = p2.createShuffleSpec( - requiredChildDistribution(1).asInstanceOf[ClusteredDistribution]) - leftSpec.isCompatibleWith(rightSpec) - case _ => - false - } + requiredChildDistribution: Seq[Distribution]): Boolean = { + (left.outputPartitioning, right.outputPartitioning) match { + case (p1: ShufflePartitionIdPassThrough, p2: ShufflePartitionIdPassThrough) => + assert(requiredChildDistribution.length == 2) + val leftSpec = p1.createShuffleSpec( + requiredChildDistribution.head.asInstanceOf[ClusteredDistribution]) + val rightSpec = p2.createShuffleSpec( + requiredChildDistribution(1).asInstanceOf[ClusteredDistribution]) + leftSpec.isCompatibleWith(rightSpec) + case _ => + false + } } // Similar to `OptimizeSkewedJoin.canSplitRightSide` From 4e52066390c09ad2c84a5a36169fd3f7a5002f10 Mon Sep 17 00:00:00 2001 From: Shujing Yang Date: Thu, 25 Sep 2025 11:05:30 -0700 Subject: [PATCH 5/7] address comments --- .../plans/physical/partitioning.scala | 2 +- .../spark/sql/catalyst/ShuffleSpecSuite.scala | 2 +- .../exchange/EnsureRequirements.scala | 19 ++-- .../exchange/EnsureRequirementsSuite.scala | 87 +++++++------------ 4 files changed, 42 insertions(+), 68 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 6b110cab262e9..d09af14350a9a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -638,7 +638,6 @@ case class ShufflePartitionIdPassThrough( expr: DirectShufflePartitionID, numPartitions: Int) extends Expression with Partitioning with Unevaluable { - // We don't support creating partitioning for ShufflePartitionIdPassThrough. override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = { ShufflePartitionIdPassThroughSpec(this, distribution) } @@ -1009,6 +1008,7 @@ case class ShufflePartitionIdPassThroughSpec( false } + // We don't support creating partitioning for ShufflePartitionIdPassThrough. 
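// As the EnsureRequirementsSuite tests below exercise: with this returning false,
// EnsureRequirements never re-shuffles the other side into a pass-through
// partitioning; when two sides are incompatible, both children fall back to a
// hash shuffle on the join keys.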
override def canCreatePartitioning: Boolean = false override def numPartitions: Int = partitioning.numPartitions diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala index dac9b1a906b32..a885c8f248f45 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala @@ -516,7 +516,7 @@ class ShuffleSpecSuite extends SparkFunSuite with SQLHelper { val dist3 = ClusteredDistribution(Seq($"e", $"b")) checkCompatible( p1.createShuffleSpec(dist3), - p2.createShuffleSpec(dist2), + p2.createShuffleSpec(dist), expected = false ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 06590c20ac28d..cf9ffc2fd876e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -165,27 +165,28 @@ case class EnsureRequirements( // Check if the following conditions are satisfied: // 1. There are exactly two children (e.g., join). Note that Spark doesn't support // multi-way join at the moment, so this check should be sufficient. - // 2. All children are of `KeyGroupedPartitioning`, and they are compatible with each other + // 2. All children are of the same partitioning, and they are compatible with each other // If both are true, skip shuffle. - val isKeyGroupCompatible = parent.isDefined && + val areChildrenCompatible = parent.isDefined && children.length == 2 && childrenIndexes.length == 2 && { val left = children.head val right = children(1) + + // key group compatibility check val newChildren = checkKeyGroupCompatible( parent.get, left, right, requiredChildDistributions) if (newChildren.isDefined) { children = newChildren.get + true + } else { + // If key group check fails, check ShufflePartitionIdPassThrough compatibility + checkShufflePartitionIdPassThroughCompatible( + left, right, requiredChildDistributions) } - newChildren.isDefined } - val isShufflePassThroughCompatible = !isKeyGroupCompatible && - parent.isDefined && children.length == 2 && childrenIndexes.length == 2 && - checkShufflePartitionIdPassThroughCompatible( - children.head, children(1), requiredChildDistributions) - children = children.zip(requiredChildDistributions).zipWithIndex.map { - case ((child, _), idx) if isKeyGroupCompatible || isShufflePassThroughCompatible || + case ((child, _), idx) if areChildrenCompatible || !childrenIndexes.contains(idx) => child case ((child, dist), idx) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index 4887c0b704b02..9d68d90e39c0b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -1197,7 +1197,7 @@ class EnsureRequirementsSuite extends SharedSparkSession { TransformExpression(BucketFunction, expr, Some(numBuckets)) } - test("ShufflePartitionIdPassThrough - avoid necessary shuffle when they are compatible") { + test("ShufflePartitionIdPassThrough - avoid unnecessary 
shuffle when children are compatible") { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") { val plan1 = DummySparkPlan( outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5)) @@ -1253,7 +1253,8 @@ class EnsureRequirementsSuite extends SharedSparkSession { val plan2 = DummySparkPlan( outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprC), 5)) // Join on different keys than partitioning keys - val smjExec = SortMergeJoinExec(exprB :: Nil, exprD :: Nil, Inner, None, plan1, plan2) + val smjExec = SortMergeJoinExec(exprA :: exprB :: Nil, exprD :: exprC :: Nil, Inner, None, + plan1, plan2) EnsureRequirements.apply(smjExec) match { case SortMergeJoinExec(_, _, _, _, @@ -1306,32 +1307,6 @@ class EnsureRequirementsSuite extends SharedSparkSession { } } - test("ShufflePartitionIdPassThrough - incompatible due to different expressions " + - "with same base column") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") { - // Even though both use exprA as base and have same numPartitions, - // different Pmod operations make them incompatible - val plan1 = DummySparkPlan( - outputPartitioning = ShufflePartitionIdPassThrough( - DirectShufflePartitionID(Pmod(exprA, Literal(10))), 5)) - val plan2 = DummySparkPlan( - outputPartitioning = ShufflePartitionIdPassThrough( - DirectShufflePartitionID(Pmod(exprA, Literal(5))), 5)) - val smjExec = SortMergeJoinExec(exprA :: Nil, exprA :: Nil, Inner, None, plan1, plan2) - - EnsureRequirements.apply(smjExec) match { - case SortMergeJoinExec(_, _, _, _, - SortExec(_, _, ShuffleExchangeExec(p1: HashPartitioning, _, _, _), _), - SortExec(_, _, ShuffleExchangeExec(p2: HashPartitioning, _, _, _), _), _) => - // Both sides should be shuffled due to expression mismatch - assert(p1.numPartitions == 10) - assert(p2.numPartitions == 10) - assert(p1.expressions == Seq(exprA)) - assert(p2.expressions == Seq(exprA)) - case other => fail(s"Expected shuffles on both sides, but got: $other") - } - } - } test("ShufflePartitionIdPassThrough - compatible with multiple clustering keys") { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") { @@ -1359,6 +1334,33 @@ class EnsureRequirementsSuite extends SharedSparkSession { case other => fail(s"We don't expect shuffle on neither sides with multiple " + s"clustering keys, but got: $other") } + + // Test case 2: partition key matches at position 1 + // Both sides partitioned by exprB and join on (exprA, exprB) + // Should be compatible because partition key exprB matches at position 1 in join keys + val plan3 = DummySparkPlan( + outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 5)) + val plan4 = DummySparkPlan( + outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 5)) + val smjExec2 = SortMergeJoinExec(exprA :: exprB :: Nil, exprA :: exprB :: Nil, Inner, None, + plan3, plan4) + + EnsureRequirements.apply(smjExec2) match { + case SortMergeJoinExec( + leftKeys, + rightKeys, + _, + _, + SortExec(_, _, DummySparkPlan(_, _, _: ShufflePartitionIdPassThrough, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, _: ShufflePartitionIdPassThrough, _, _), _), + _ + ) => + // No shuffles because exprB (partition key) appears at position 1 in join keys + assert(leftKeys === Seq(exprA, exprB)) + assert(rightKeys === Seq(exprA, exprB)) + case other => fail(s"Expected no shuffles due to position overlap at position 1, " + + s"but got: $other") + } } } @@ -1414,35 +1416,6 @@ class EnsureRequirementsSuite extends SharedSparkSession 
{ } } - test("ShufflePartitionIdPassThrough - compatible when partition key matches at any position") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") { - // Both sides partitioned by exprB and join on (exprA, exprB) - // Should be compatible because partition key exprB matches at position 1 in join keys - val plan1 = DummySparkPlan( - outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 5)) - val plan2 = DummySparkPlan( - outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 5)) - val smjExec = SortMergeJoinExec(exprA :: exprB :: Nil, exprA :: exprB :: Nil, Inner, None, - plan1, plan2) - - EnsureRequirements.apply(smjExec) match { - case SortMergeJoinExec( - leftKeys, - rightKeys, - _, - _, - SortExec(_, _, DummySparkPlan(_, _, _: ShufflePartitionIdPassThrough, _, _), _), - SortExec(_, _, DummySparkPlan(_, _, _: ShufflePartitionIdPassThrough, _, _), _), - _ - ) => - // No shuffles because exprB (partition key) appears at position 1 in join keys - assert(leftKeys === Seq(exprA, exprB)) - assert(rightKeys === Seq(exprA, exprB)) - case other => fail(s"Expected no shuffles due to position overlap at position 1, " + - s"but got: $other") - } - } - } def years(expr: Expression): TransformExpression = { TransformExpression(YearsFunction, Seq(expr)) From 69df8352ff2d62b3aba02142aeb05cc7359cae57 Mon Sep 17 00:00:00 2001 From: Shujing Yang Date: Mon, 29 Sep 2025 00:03:33 -0700 Subject: [PATCH 6/7] address comments --- .../spark/sql/catalyst/plans/physical/partitioning.scala | 2 +- .../scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index d09af14350a9a..1cbb49c7a1f73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -995,7 +995,7 @@ case class ShufflePartitionIdPassThroughSpec( // as the partitioning expression, we check compatibility as follows: // 1. Same number of clustering expressions // 2. Same number of partitions - // 3. each pair of partitioning expression from both sides has overlapping positions in their + // 3. each partitioning expression from both sides has overlapping positions in their // corresponding distributions. 
distribution.clustering.length == otherDistribution.clustering.length && partitioning.numPartitions == otherPartitioning.numPartitions && { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala index a885c8f248f45..fd783385f1640 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala @@ -523,7 +523,6 @@ class ShuffleSpecSuite extends SparkFunSuite with SQLHelper { test("compatibility: ShufflePartitionIdPassThroughSpec vs other specs") { val dist = ClusteredDistribution(Seq($"a", $"b")) - val p = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"a"), 10) // Compatibility with SinglePartitionShuffleSpec when numPartitions is 1 val p1 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"a"), 1) @@ -543,7 +542,7 @@ class ShuffleSpecSuite extends SparkFunSuite with SQLHelper { // Incompatible with HashShuffleSpec val p2 = HashPartitioning(Seq($"c"), 10) checkCompatible( - p.createShuffleSpec(dist), + ShufflePartitionIdPassThrough(DirectShufflePartitionID($"a"), 10).createShuffleSpec(dist), p2.createShuffleSpec(ClusteredDistribution(Seq($"c", $"d"))), expected = false ) From f907b5b8f9e4757f87793689e33475fdb7b58838 Mon Sep 17 00:00:00 2001 From: Shujing Yang Date: Wed, 1 Oct 2025 15:24:30 -0700 Subject: [PATCH 7/7] ckp --- .../spark/sql/catalyst/ShuffleSpecSuite.scala | 42 ++++----- .../exchange/EnsureRequirementsSuite.scala | 91 +++++++++---------- 2 files changed, 63 insertions(+), 70 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala index fd783385f1640..a41f5146386f4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala @@ -482,68 +482,62 @@ class ShuffleSpecSuite extends SparkFunSuite with SQLHelper { } test("compatibility: ShufflePartitionIdPassThroughSpec on both sides") { - val dist = ClusteredDistribution(Seq($"a", $"b")) - val p1 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"a"), 10) - val p2 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"c"), 10) + val ab = ClusteredDistribution(Seq($"a", $"b")) + val cd = ClusteredDistribution(Seq($"c", $"d")) + val passThrough_a_10 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"a"), 10) // Identical specs should be compatible checkCompatible( - p1.createShuffleSpec(dist), - p2.createShuffleSpec(ClusteredDistribution(Seq($"c", $"d"))), + passThrough_a_10.createShuffleSpec(ab), + ShufflePartitionIdPassThrough(DirectShufflePartitionID($"c"), 10).createShuffleSpec(cd), expected = true ) // Different number of partitions should be incompatible - val p3 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"c"), 5) checkCompatible( - p1.createShuffleSpec(dist), - p3.createShuffleSpec(ClusteredDistribution(Seq($"c", $"d"))), + passThrough_a_10.createShuffleSpec(ab), + ShufflePartitionIdPassThrough(DirectShufflePartitionID($"c"), 5).createShuffleSpec(cd), expected = false ) // Mismatched key positions should be incompatible - val dist1 = ClusteredDistribution(Seq($"a", $"b")) - val p4 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"b"), 10) // Key at pos 1 - val dist2 = 
ClusteredDistribution(Seq($"c", $"d")) - val p5 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"c"), 10) // Key at pos 0 checkCompatible( - p4.createShuffleSpec(dist1), - p5.createShuffleSpec(dist2), + ShufflePartitionIdPassThrough(DirectShufflePartitionID($"b"), 10).createShuffleSpec(ab), + ShufflePartitionIdPassThrough(DirectShufflePartitionID($"c"), 10).createShuffleSpec(cd), expected = false ) // Mismatched clustering keys - val dist3 = ClusteredDistribution(Seq($"e", $"b")) checkCompatible( - p1.createShuffleSpec(dist3), - p2.createShuffleSpec(dist), + passThrough_a_10.createShuffleSpec(ClusteredDistribution(Seq($"e", $"b"))), + ShufflePartitionIdPassThrough(DirectShufflePartitionID($"c"), 10).createShuffleSpec(ab), expected = false ) } test("compatibility: ShufflePartitionIdPassThroughSpec vs other specs") { - val dist = ClusteredDistribution(Seq($"a", $"b")) + val ab = ClusteredDistribution(Seq($"a", $"b")) + val cd = ClusteredDistribution(Seq($"c", $"d")) + val passThrough_a_10 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"a"), 10) // Compatibility with SinglePartitionShuffleSpec when numPartitions is 1 - val p1 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"a"), 1) checkCompatible( - p1.createShuffleSpec(dist), + ShufflePartitionIdPassThrough(DirectShufflePartitionID($"a"), 1).createShuffleSpec(ab), SinglePartitionShuffleSpec, expected = true ) // Incompatible with SinglePartitionShuffleSpec when numPartitions > 1 checkCompatible( - p.createShuffleSpec(dist), + passThrough_a_10.createShuffleSpec(ab), SinglePartitionShuffleSpec, expected = false ) // Incompatible with HashShuffleSpec - val p2 = HashPartitioning(Seq($"c"), 10) checkCompatible( - ShufflePartitionIdPassThrough(DirectShufflePartitionID($"a"), 10).createShuffleSpec(dist), - p2.createShuffleSpec(ClusteredDistribution(Seq($"c", $"d"))), + passThrough_a_10.createShuffleSpec(ab), + HashShuffleSpec(HashPartitioning(Seq($"c"), 10), cd), expected = false ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index 9d68d90e39c0b..7b9b950e31b4c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -1199,13 +1199,13 @@ class EnsureRequirementsSuite extends SharedSparkSession { test("ShufflePartitionIdPassThrough - avoid unnecessary shuffle when children are compatible") { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") { - val plan1 = DummySparkPlan( - outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5)) - val plan2 = DummySparkPlan( - outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5)) - val smjExec = SortMergeJoinExec(exprA :: Nil, exprA :: Nil, Inner, None, plan1, plan2) + val passThrough_a_5 = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5) - EnsureRequirements.apply(smjExec) match { + val leftPlan = DummySparkPlan(outputPartitioning = passThrough_a_5) + val rightPlan = DummySparkPlan(outputPartitioning = passThrough_a_5) + val join = SortMergeJoinExec(exprA :: Nil, exprA :: Nil, Inner, None, leftPlan, rightPlan) + + EnsureRequirements.apply(join) match { case SortMergeJoinExec( leftKeys, rightKeys, @@ -1225,13 +1225,13 @@ class EnsureRequirementsSuite extends SharedSparkSession { 
test("ShufflePartitionIdPassThrough incompatibility - different partitions") { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") { // Different number of partitions - should add shuffles - val plan1 = DummySparkPlan( + val leftPlan = DummySparkPlan( outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5)) - val plan2 = DummySparkPlan( + val rightPlan = DummySparkPlan( outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 8)) - val smjExec = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, plan1, plan2) + val join = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, leftPlan, rightPlan) - EnsureRequirements.apply(smjExec) match { + EnsureRequirements.apply(join) match { case SortMergeJoinExec(_, _, _, _, SortExec(_, _, ShuffleExchangeExec(p1: HashPartitioning, _, _, _), _), SortExec(_, _, ShuffleExchangeExec(p2: HashPartitioning, _, _, _), _), _) => @@ -1248,15 +1248,15 @@ class EnsureRequirementsSuite extends SharedSparkSession { test("ShufflePartitionIdPassThrough incompatibility - key position mismatch") { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") { // Key position mismatch - should add shuffles - val plan1 = DummySparkPlan( + val leftPlan = DummySparkPlan( outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5)) - val plan2 = DummySparkPlan( + val rightPlan = DummySparkPlan( outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprC), 5)) // Join on different keys than partitioning keys - val smjExec = SortMergeJoinExec(exprA :: exprB :: Nil, exprD :: exprC :: Nil, Inner, None, - plan1, plan2) + val join = SortMergeJoinExec(exprA :: exprB :: Nil, exprD :: exprC :: Nil, Inner, None, + leftPlan, rightPlan) - EnsureRequirements.apply(smjExec) match { + EnsureRequirements.apply(join) match { case SortMergeJoinExec(_, _, _, _, SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _), SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _), _) => @@ -1269,13 +1269,13 @@ class EnsureRequirementsSuite extends SharedSparkSession { test("ShufflePartitionIdPassThrough vs HashPartitioning - always shuffles") { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") { // ShufflePartitionIdPassThrough vs HashPartitioning - always adds shuffles - val plan1 = DummySparkPlan( + val leftPlan = DummySparkPlan( outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5)) - val plan2 = DummySparkPlan( + val rightPlan = DummySparkPlan( outputPartitioning = HashPartitioning(exprB :: Nil, 5)) - val smjExec = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, plan1, plan2) + val join = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, leftPlan, rightPlan) - EnsureRequirements.apply(smjExec) match { + EnsureRequirements.apply(join) match { case SortMergeJoinExec(_, _, _, _, SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _), SortExec(_, _, _: DummySparkPlan, _), _) => @@ -1292,12 +1292,12 @@ class EnsureRequirementsSuite extends SharedSparkSession { test("ShufflePartitionIdPassThrough vs SinglePartition - shuffles added") { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "5") { // Even when compatible (numPartitions=1), shuffles added due to canCreatePartitioning=false - val plan1 = DummySparkPlan( + val leftPlan = DummySparkPlan( outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 1)) - val plan2 = DummySparkPlan(outputPartitioning = SinglePartition) - val 
smjExec = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, plan1, plan2) + val rightPlan = DummySparkPlan(outputPartitioning = SinglePartition) + val join = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, leftPlan, rightPlan) - EnsureRequirements.apply(smjExec) match { + EnsureRequirements.apply(join) match { case SortMergeJoinExec(_, _, _, _, SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _), SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _), _) => @@ -1310,16 +1310,17 @@ class EnsureRequirementsSuite extends SharedSparkSession { test("ShufflePartitionIdPassThrough - compatible with multiple clustering keys") { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") { + val passThrough_a_5 = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5) + val passThrough_b_5 = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 5) + // Both partitioned by exprA, joined on (exprA, exprB) // Should be compatible because exprA positions overlap - val plan1 = DummySparkPlan( - outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5)) - val plan2 = DummySparkPlan( - outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5)) - val smjExec = SortMergeJoinExec(exprA :: exprB :: Nil, exprA :: exprB :: Nil, Inner, None, - plan1, plan2) + val leftPlanA = DummySparkPlan(outputPartitioning = passThrough_a_5) + val rightPlanA = DummySparkPlan(outputPartitioning = passThrough_a_5) + val joinA = SortMergeJoinExec(exprA :: exprB :: Nil, exprA :: exprB :: Nil, Inner, None, + leftPlanA, rightPlanA) - EnsureRequirements.apply(smjExec) match { + EnsureRequirements.apply(joinA) match { case SortMergeJoinExec( leftKeys, rightKeys, @@ -1338,14 +1339,12 @@ class EnsureRequirementsSuite extends SharedSparkSession { // Test case 2: partition key matches at position 1 // Both sides partitioned by exprB and join on (exprA, exprB) // Should be compatible because partition key exprB matches at position 1 in join keys - val plan3 = DummySparkPlan( - outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 5)) - val plan4 = DummySparkPlan( - outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 5)) - val smjExec2 = SortMergeJoinExec(exprA :: exprB :: Nil, exprA :: exprB :: Nil, Inner, None, - plan3, plan4) + val leftPlanB = DummySparkPlan(outputPartitioning = passThrough_b_5) + val rightPlanB = DummySparkPlan(outputPartitioning = passThrough_b_5) + val joinB = SortMergeJoinExec(exprA :: exprB :: Nil, exprA :: exprB :: Nil, Inner, None, + leftPlanB, rightPlanB) - EnsureRequirements.apply(smjExec2) match { + EnsureRequirements.apply(joinB) match { case SortMergeJoinExec( leftKeys, rightKeys, @@ -1368,13 +1367,13 @@ class EnsureRequirementsSuite extends SharedSparkSession { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") { // Partitioned by exprA and exprB respectively, but joining on completely different keys // Should require shuffles because partition keys don't match join keys - val plan1 = DummySparkPlan( + val leftPlan = DummySparkPlan( outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5)) - val plan2 = DummySparkPlan( + val rightPlan = DummySparkPlan( outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 5)) - val smjExec = SortMergeJoinExec(exprC :: Nil, exprD :: Nil, Inner, None, plan1, plan2) + val join = SortMergeJoinExec(exprC :: Nil, exprD :: Nil, Inner, None, 
leftPlan, rightPlan) - EnsureRequirements.apply(smjExec) match { + EnsureRequirements.apply(join) match { case SortMergeJoinExec(_, _, _, _, SortExec(_, _, ShuffleExchangeExec(p1: HashPartitioning, _, _, _), _), SortExec(_, _, ShuffleExchangeExec(p2: HashPartitioning, _, _, _), _), _) => @@ -1395,14 +1394,14 @@ class EnsureRequirementsSuite extends SharedSparkSession { // Test if cross-position matching works: left partition key exprA matches right join key // exprA (pos 0) // and right partition key exprB matches left join key exprB (pos 1) - val plan1 = DummySparkPlan( + val leftPlan = DummySparkPlan( outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5)) - val plan2 = DummySparkPlan( + val rightPlan = DummySparkPlan( outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 5)) - val smjExec = SortMergeJoinExec(exprA :: exprB :: Nil, exprA :: exprB :: Nil, Inner, None, - plan1, plan2) + val join = SortMergeJoinExec(exprA :: exprB :: Nil, exprA :: exprB :: Nil, Inner, None, + leftPlan, rightPlan) - EnsureRequirements.apply(smjExec) match { + EnsureRequirements.apply(join) match { case SortMergeJoinExec(_, _, _, _, SortExec(_, _, ShuffleExchangeExec(p1: HashPartitioning, _, _, _), _), SortExec(_, _, ShuffleExchangeExec(p2: HashPartitioning, _, _, _), _), _) =>