diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index f855483ea3c38..1cbb49c7a1f73 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -638,7 +638,10 @@ case class ShufflePartitionIdPassThrough(
     expr: DirectShufflePartitionID,
     numPartitions: Int) extends Expression with Partitioning with Unevaluable {
 
-  // TODO(SPARK-53401): Support Shuffle Spec in Direct Partition ID Pass Through
+  override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = {
+    ShufflePartitionIdPassThroughSpec(this, distribution)
+  }
+
   def partitionIdExpression: Expression = Pmod(expr.child, Literal(numPartitions))
 
   def expressions: Seq[Expression] = expr :: Nil
@@ -966,6 +969,51 @@ object KeyGroupedShuffleSpec {
   }
 }
 
+case class ShufflePartitionIdPassThroughSpec(
+    partitioning: ShufflePartitionIdPassThrough,
+    distribution: ClusteredDistribution) extends ShuffleSpec {
+
+  /**
+   * The set of positions at which the partition key appears in the distribution's
+   * clustering keys. Similar to HashShuffleSpec, this maps the single partitioning
+   * expression to its positions in the distribution clustering keys.
+   */
+  lazy val keyPositions: mutable.BitSet = {
+    val distKeyToPos = mutable.Map.empty[Expression, mutable.BitSet]
+    distribution.clustering.zipWithIndex.foreach { case (distKey, distKeyPos) =>
+      distKeyToPos.getOrElseUpdate(distKey.canonicalized, mutable.BitSet.empty).add(distKeyPos)
+    }
+    distKeyToPos.getOrElse(partitioning.expr.child.canonicalized, mutable.BitSet.empty)
+  }
+
+  override def isCompatibleWith(other: ShuffleSpec): Boolean = other match {
+    case SinglePartitionShuffleSpec =>
+      partitioning.numPartitions == 1
+    case otherPassThroughSpec @ ShufflePartitionIdPassThroughSpec(
+        otherPartitioning, otherDistribution) =>
+      // As ShufflePartitionIdPassThrough only allows a single expression
+      // as the partitioning expression, we check compatibility as follows:
+      //   1. Both sides have the same number of clustering expressions.
+      //   2. Both sides have the same number of partitions.
+      //   3. The partition keys of both sides occupy overlapping positions in their
+      //      corresponding distributions.
+      distribution.clustering.length == otherDistribution.clustering.length &&
+        partitioning.numPartitions == otherPartitioning.numPartitions && {
+          val otherKeyPositions = otherPassThroughSpec.keyPositions
+          keyPositions.intersect(otherKeyPositions).nonEmpty
+        }
+    case ShuffleSpecCollection(specs) =>
+      specs.exists(isCompatibleWith)
+    case _ =>
+      false
+  }
+
+  // We don't support creating a partitioning for ShufflePartitionIdPassThrough.
+  override def canCreatePartitioning: Boolean = false
+
+  override def numPartitions: Int = partitioning.numPartitions
+}
+
 case class ShuffleSpecCollection(specs: Seq[ShuffleSpec]) extends ShuffleSpec {
   override def isCompatibleWith(other: ShuffleSpec): Boolean = {
     specs.exists(_.isCompatibleWith(other))
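For intuition, here is a minimal standalone sketch of the pass-through compatibility rule above (plain Scala with illustrative names, not part of the patch; it omits the clustering-length check and the SinglePartition/ShuffleSpecCollection cases):

import scala.collection.mutable

// Each spec reduces to the positions of its partition key within the
// clustering keys, plus its partition count.
case class PassThroughSpecModel(keyPositions: mutable.BitSet, numPartitions: Int) {
  def compatibleWith(other: PassThroughSpecModel): Boolean =
    numPartitions == other.numPartitions &&
      keyPositions.intersect(other.keyPositions).nonEmpty
}

// Both sides keyed on clustering position 0 with 10 partitions: compatible.
assert(PassThroughSpecModel(mutable.BitSet(0), 10)
  .compatibleWith(PassThroughSpecModel(mutable.BitSet(0), 10)))
// Positions {0} vs {1} do not overlap: incompatible, so a shuffle is required.
assert(!PassThroughSpecModel(mutable.BitSet(0), 10)
  .compatibleWith(PassThroughSpecModel(mutable.BitSet(1), 10)))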
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala
index fc5d39fd9c2bb..a41f5146386f4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst
 
 import org.apache.spark.{SparkFunSuite, SparkUnsupportedOperationException}
 import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.DirectShufflePartitionID
 import org.apache.spark.sql.catalyst.plans.SQLHelper
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.internal.SQLConf
@@ -479,4 +480,65 @@ class ShuffleSpecSuite extends SparkFunSuite with SQLHelper {
         "methodName" -> "createPartitioning$",
         "className" -> "org.apache.spark.sql.catalyst.plans.physical.ShuffleSpec"))
   }
+
+  test("compatibility: ShufflePartitionIdPassThroughSpec on both sides") {
+    val ab = ClusteredDistribution(Seq($"a", $"b"))
+    val cd = ClusteredDistribution(Seq($"c", $"d"))
+    val passThrough_a_10 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"a"), 10)
+
+    // Same key position and partition count on both sides should be compatible
+    checkCompatible(
+      passThrough_a_10.createShuffleSpec(ab),
+      ShufflePartitionIdPassThrough(DirectShufflePartitionID($"c"), 10).createShuffleSpec(cd),
+      expected = true
+    )
+
+    // Different number of partitions should be incompatible
+    checkCompatible(
+      passThrough_a_10.createShuffleSpec(ab),
+      ShufflePartitionIdPassThrough(DirectShufflePartitionID($"c"), 5).createShuffleSpec(cd),
+      expected = false
+    )
+
+    // Mismatched key positions should be incompatible
+    checkCompatible(
+      ShufflePartitionIdPassThrough(DirectShufflePartitionID($"b"), 10).createShuffleSpec(ab),
+      ShufflePartitionIdPassThrough(DirectShufflePartitionID($"c"), 10).createShuffleSpec(cd),
+      expected = false
+    )
+
+    // Partition key absent from the clustering keys should be incompatible
+    checkCompatible(
+      passThrough_a_10.createShuffleSpec(ClusteredDistribution(Seq($"e", $"b"))),
+      ShufflePartitionIdPassThrough(DirectShufflePartitionID($"c"), 10).createShuffleSpec(ab),
+      expected = false
+    )
+  }
+
+  test("compatibility: ShufflePartitionIdPassThroughSpec vs other specs") {
+    val ab = ClusteredDistribution(Seq($"a", $"b"))
+    val cd = ClusteredDistribution(Seq($"c", $"d"))
+    val passThrough_a_10 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"a"), 10)
+
+    // Compatible with SinglePartitionShuffleSpec when numPartitions is 1
+    checkCompatible(
+      ShufflePartitionIdPassThrough(DirectShufflePartitionID($"a"), 1).createShuffleSpec(ab),
+      SinglePartitionShuffleSpec,
+      expected = true
+    )
+
+    // Incompatible with SinglePartitionShuffleSpec when numPartitions > 1
+    checkCompatible(
+      passThrough_a_10.createShuffleSpec(ab),
+      SinglePartitionShuffleSpec,
+      expected = false
+    )
+
+    // Incompatible with HashShuffleSpec
+    checkCompatible(
+      passThrough_a_10.createShuffleSpec(ab),
+      HashShuffleSpec(HashPartitioning(Seq($"c"), 10), cd),
+      expected = false
+    )
+  }
 }
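The cases above hinge on how createShuffleSpec derives key positions from a distribution. A small sketch of that computation, using strings as stand-ins for canonicalized expressions (hypothetical helper, not in the patch):

import scala.collection.mutable

// Positions at which partitionKey occurs in the clustering keys.
def keyPositionsOf(partitionKey: String, clustering: Seq[String]): mutable.BitSet = {
  val positions = mutable.BitSet.empty
  clustering.zipWithIndex.foreach { case (key, pos) =>
    if (key == partitionKey) positions += pos
  }
  positions
}

// DirectShufflePartitionID($"a") over ClusteredDistribution(Seq($"a", $"b")): {0}.
assert(keyPositionsOf("a", Seq("a", "b")) == mutable.BitSet(0))
// A partition key absent from the clustering keys yields the empty set, which
// intersects with nothing (the "Partition key absent" case in the test above).
assert(keyPositionsOf("a", Seq("e", "b")).isEmpty)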
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index a0fc4b65fdbf3..cf9ffc2fd876e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -165,22 +165,29 @@ case class EnsureRequirements(
       // Check if the following conditions are satisfied:
       //   1. There are exactly two children (e.g., join). Note that Spark doesn't support
       //      multi-way join at the moment, so this check should be sufficient.
-      //   2. All children are of `KeyGroupedPartitioning`, and they are compatible with each other
+      //   2. All children are of the same partitioning, and they are compatible with each other
      // If both are true, skip shuffle.
-      val isKeyGroupCompatible = parent.isDefined &&
+      val areChildrenCompatible = parent.isDefined &&
          children.length == 2 && childrenIndexes.length == 2 && {
        val left = children.head
        val right = children(1)
+
+        // Key-grouped partitioning compatibility check
        val newChildren = checkKeyGroupCompatible(
          parent.get, left, right, requiredChildDistributions)
        if (newChildren.isDefined) {
          children = newChildren.get
+          true
+        } else {
+          // If the key-grouped check fails, fall back to the ShufflePartitionIdPassThrough check
+          checkShufflePartitionIdPassThroughCompatible(
+            left, right, requiredChildDistributions)
        }
-        newChildren.isDefined
      }
 
      children = children.zip(requiredChildDistributions).zipWithIndex.map {
-        case ((child, _), idx) if isKeyGroupCompatible || !childrenIndexes.contains(idx) =>
+        case ((child, _), idx) if areChildrenCompatible ||
+            !childrenIndexes.contains(idx) =>
          child
        case ((child, dist), idx) =>
          if (bestSpecOpt.isDefined && bestSpecOpt.get.isCompatibleWith(specs(idx))) {
@@ -600,6 +607,23 @@ case class EnsureRequirements(
     if (isCompatible) Some(Seq(newLeft, newRight)) else None
   }
 
+  private def checkShufflePartitionIdPassThroughCompatible(
+      left: SparkPlan,
+      right: SparkPlan,
+      requiredChildDistribution: Seq[Distribution]): Boolean = {
+    (left.outputPartitioning, right.outputPartitioning) match {
+      case (p1: ShufflePartitionIdPassThrough, p2: ShufflePartitionIdPassThrough) =>
+        assert(requiredChildDistribution.length == 2)
+        val leftSpec = p1.createShuffleSpec(
+          requiredChildDistribution.head.asInstanceOf[ClusteredDistribution])
+        val rightSpec = p2.createShuffleSpec(
+          requiredChildDistribution(1).asInstanceOf[ClusteredDistribution])
+        leftSpec.isCompatibleWith(rightSpec)
+      case _ =>
+        false
+    }
+  }
+
   // Similar to `OptimizeSkewedJoin.canSplitRightSide`
   private def canReplicateLeftSide(joinType: JoinType): Boolean = {
     joinType == Inner || joinType == Cross || joinType == RightOuter
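The fallback only ever fires for a pass-through pair; a condensed model of the match above (illustrative types, not the patch's API):

sealed trait Part
case object PassThroughPart extends Part
case object HashPart extends Part

// Mirrors checkShufflePartitionIdPassThroughCompatible: anything other than a
// (pass-through, pass-through) pair falls through to false, so mixed pairings
// always get shuffles.
def isPassThroughPair(left: Part, right: Part): Boolean = (left, right) match {
  case (PassThroughPart, PassThroughPart) => true // the real code then consults the specs
  case _ => false
}

assert(isPassThroughPair(PassThroughPart, PassThroughPart))
assert(!isPassThroughPair(PassThroughPart, HashPart))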
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
index 3b0bb088a1076..7b9b950e31b4c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.exchange
 import org.apache.spark.api.python.PythonEvalType
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.DirectShufflePartitionID
 import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
 import org.apache.spark.sql.catalyst.optimizer.BuildRight
 import org.apache.spark.sql.catalyst.plans.Inner
@@ -1196,6 +1197,223 @@ class EnsureRequirementsSuite extends SharedSparkSession {
     TransformExpression(BucketFunction, expr, Some(numBuckets))
   }
 
+  test("ShufflePartitionIdPassThrough - avoid unnecessary shuffle when children are compatible") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      val passThrough_a_5 = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5)
+
+      val leftPlan = DummySparkPlan(outputPartitioning = passThrough_a_5)
+      val rightPlan = DummySparkPlan(outputPartitioning = passThrough_a_5)
+      val join = SortMergeJoinExec(exprA :: Nil, exprA :: Nil, Inner, None, leftPlan, rightPlan)
+
+      EnsureRequirements.apply(join) match {
+        case SortMergeJoinExec(
+          leftKeys,
+          rightKeys,
+          _,
+          _,
+          SortExec(_, _, DummySparkPlan(_, _, _: ShufflePartitionIdPassThrough, _, _), _),
+          SortExec(_, _, DummySparkPlan(_, _, _: ShufflePartitionIdPassThrough, _, _), _),
+          _
+        ) =>
+          assert(leftKeys === Seq(exprA))
+          assert(rightKeys === Seq(exprA))
+        case other => fail(s"We don't expect a shuffle on either side, but got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough incompatibility - different partitions") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // Different number of partitions - should add shuffles
+      val leftPlan = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val rightPlan = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 8))
+      val join = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, leftPlan, rightPlan)
+
+      EnsureRequirements.apply(join) match {
+        case SortMergeJoinExec(_, _, _, _,
+          SortExec(_, _, ShuffleExchangeExec(p1: HashPartitioning, _, _, _), _),
+          SortExec(_, _, ShuffleExchangeExec(p2: HashPartitioning, _, _, _), _), _) =>
+          // Both sides should be shuffled to the default number of partitions
+          assert(p1.numPartitions == 10)
+          assert(p2.numPartitions == 10)
+          assert(p1.expressions == Seq(exprA))
+          assert(p2.expressions == Seq(exprB))
+        case other => fail(s"Expected shuffles on both sides, but got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough incompatibility - key position mismatch") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // Key position mismatch - should add shuffles
+      val leftPlan = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val rightPlan = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprC), 5))
+      // exprA sits at position 0 of the left join keys, exprC at position 1 of the right
+      val join = SortMergeJoinExec(exprA :: exprB :: Nil, exprD :: exprC :: Nil, Inner, None,
+        leftPlan, rightPlan)
+
+      EnsureRequirements.apply(join) match {
+        case SortMergeJoinExec(_, _, _, _,
+          SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _),
+          SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _), _) =>
+          // Both sides shuffled due to the key position mismatch
+        case other => fail(s"Expected shuffles on both sides, but got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough vs HashPartitioning - always shuffles") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // Mixing ShufflePartitionIdPassThrough with HashPartitioning always adds a shuffle
+      val leftPlan = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val rightPlan = DummySparkPlan(
+        outputPartitioning = HashPartitioning(exprB :: Nil, 5))
+      val join = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, leftPlan, rightPlan)
+
+      EnsureRequirements.apply(join) match {
+        case SortMergeJoinExec(_, _, _, _,
+          SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _),
+          SortExec(_, _, _: DummySparkPlan, _), _) =>
+          // Left side shuffled, right side kept as-is
+        case SortMergeJoinExec(_, _, _, _,
+          SortExec(_, _, _: DummySparkPlan, _),
+          SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _), _) =>
+          // Right side shuffled, left side kept as-is
+        case other => fail(s"Expected a shuffle on at least one side, but got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough vs SinglePartition - shuffles added") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "5") {
+      // Even when compatible (numPartitions=1), shuffles added due to canCreatePartitioning=false
+      val leftPlan = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 1))
+      val rightPlan = DummySparkPlan(outputPartitioning = SinglePartition)
+      val join = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, leftPlan, rightPlan)
+
+      EnsureRequirements.apply(join) match {
+        case SortMergeJoinExec(_, _, _, _,
+          SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _),
+          SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _), _) =>
+          // Both sides shuffled due to canCreatePartitioning = false
+        case other => fail(s"Expected shuffles on both sides, but got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough - compatible with multiple clustering keys") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      val passThrough_a_5 = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5)
+      val passThrough_b_5 = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 5)
+
+      // Test case 1: both sides partitioned by exprA, joined on (exprA, exprB)
+      // Should be compatible because the exprA key positions overlap at position 0
+      val leftPlanA = DummySparkPlan(outputPartitioning = passThrough_a_5)
+      val rightPlanA = DummySparkPlan(outputPartitioning = passThrough_a_5)
+      val joinA = SortMergeJoinExec(exprA :: exprB :: Nil, exprA :: exprB :: Nil, Inner, None,
+        leftPlanA, rightPlanA)
+
+      EnsureRequirements.apply(joinA) match {
+        case SortMergeJoinExec(
+          leftKeys,
+          rightKeys,
+          _,
+          _,
+          SortExec(_, _, DummySparkPlan(_, _, _: ShufflePartitionIdPassThrough, _, _), _),
+          SortExec(_, _, DummySparkPlan(_, _, _: ShufflePartitionIdPassThrough, _, _), _),
+          _
+        ) =>
+          assert(leftKeys === Seq(exprA, exprB))
+          assert(rightKeys === Seq(exprA, exprB))
+        case other => fail(s"We don't expect a shuffle on either side with multiple " +
+          s"clustering keys, but got: $other")
+      }
+
+      // Test case 2: partition key matches at position 1
+      // Both sides partitioned by exprB and joined on (exprA, exprB)
+      // Should be compatible because partition key exprB matches at position 1 in join keys
+      val leftPlanB = DummySparkPlan(outputPartitioning = passThrough_b_5)
+      val rightPlanB = DummySparkPlan(outputPartitioning = passThrough_b_5)
+      val joinB = SortMergeJoinExec(exprA :: exprB :: Nil, exprA :: exprB :: Nil, Inner, None,
+        leftPlanB, rightPlanB)
+
+      EnsureRequirements.apply(joinB) match {
+        case SortMergeJoinExec(
+          leftKeys,
+          rightKeys,
+          _,
+          _,
+          SortExec(_, _, DummySparkPlan(_, _, _: ShufflePartitionIdPassThrough, _, _), _),
+          SortExec(_, _, DummySparkPlan(_, _, _: ShufflePartitionIdPassThrough, _, _), _),
+          _
+        ) =>
+          // No shuffles because exprB (the partition key) appears at position 1 in the join keys
+          assert(leftKeys === Seq(exprA, exprB))
+          assert(rightKeys === Seq(exprA, exprB))
+        case other => fail(s"Expected no shuffles due to position overlap at position 1, " +
+          s"but got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough - incompatible when partition key not in join keys") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // Partitioned by exprA and exprB respectively, but joining on completely different keys
+      // Should require shuffles because the partition keys don't appear in the join keys
+      val leftPlan = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val rightPlan = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 5))
+      val join = SortMergeJoinExec(exprC :: Nil, exprD :: Nil, Inner, None, leftPlan, rightPlan)
+
+      EnsureRequirements.apply(join) match {
+        case SortMergeJoinExec(_, _, _, _,
+          SortExec(_, _, ShuffleExchangeExec(p1: HashPartitioning, _, _, _), _),
+          SortExec(_, _, ShuffleExchangeExec(p2: HashPartitioning, _, _, _), _), _) =>
+          // Both sides should be shuffled because the partition keys are not join keys
+          assert(p1.numPartitions == 10)
+          assert(p2.numPartitions == 10)
+          assert(p1.expressions == Seq(exprC))
+          assert(p2.expressions == Seq(exprD))
+        case other => fail(s"Expected shuffles on both sides due to key mismatch, but got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough - cross position matching behavior") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // Left partitioned by exprA (position 0 in the join keys), right partitioned by
+      // exprB (position 1). Each partition key matches some join key, but the key
+      // positions {0} and {1} do not overlap, so the specs are incompatible and
+      // shuffles are required on both sides. Cross-position matches alone are not
+      // sufficient for compatibility.
+      val leftPlan = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val rightPlan = DummySparkPlan(
+        outputPartitioning = ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 5))
+      val join = SortMergeJoinExec(exprA :: exprB :: Nil, exprA :: exprB :: Nil, Inner, None,
+        leftPlan, rightPlan)
+
+      EnsureRequirements.apply(join) match {
+        case SortMergeJoinExec(_, _, _, _,
+          SortExec(_, _, ShuffleExchangeExec(p1: HashPartitioning, _, _, _), _),
+          SortExec(_, _, ShuffleExchangeExec(p2: HashPartitioning, _, _, _), _), _) =>
+          assert(p1.numPartitions == 10)
+          assert(p2.numPartitions == 10)
+          assert(p1.expressions == Seq(exprA, exprB))
+          assert(p2.expressions == Seq(exprA, exprB))
+        case other => fail(s"Expected shuffles on both sides due to non-overlapping key " +
+          s"positions, but got: $other")
+      }
+    }
+  }
+
   def years(expr: Expression): TransformExpression = {
     TransformExpression(YearsFunction, Seq(expr))
   }
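A final note on the partitioning semantics these suites exercise: partitionIdExpression above is Pmod(expr.child, Literal(numPartitions)), i.e. a row's target partition is the non-negative modulo of the user-supplied partition id. A plain-Scala model of that behavior (hypothetical helper, not in the patch):

// Spark's pmod: like %, but the result is always in [0, n).
def pmod(a: Int, n: Int): Int = { val r = a % n; if (r < 0) r + n else r }

val numPartitions = 10
assert(pmod(42, numPartitions) == 2)  // partition id 42 lands in partition 2
assert(pmod(-3, numPartitions) == 7)  // negative ids wrap around, unlike plain %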