@@ -638,7 +638,10 @@ case class ShufflePartitionIdPassThrough(
expr: DirectShufflePartitionID,
numPartitions: Int) extends Expression with Partitioning with Unevaluable {

override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = {
ShufflePartitionIdPassThroughSpec(this, distribution)
}

def partitionIdExpression: Expression = Pmod(expr.child, Literal(numPartitions))
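// Worked example (illustrative, not part of the patch): Pmod is used rather
// than `%` so that negative inputs still land in a valid partition. With
// numPartitions = 10, an input of -3 maps to pmod(-3, 10) = 7, whereas
// -3 % 10 would yield -3, which is not a valid partition id.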

def expressions: Seq[Expression] = expr :: Nil
@@ -966,6 +969,51 @@ object KeyGroupedShuffleSpec {
}
}

case class ShufflePartitionIdPassThroughSpec(
partitioning: ShufflePartitionIdPassThrough,
distribution: ClusteredDistribution) extends ShuffleSpec {

/**
* The set of positions in the distribution's clustering keys that match the
* partitioning expression, as a bit set. Unlike HashShuffleSpec, which tracks one set
* of positions per partitioning expression, ShufflePartitionIdPassThrough carries
* exactly one partitioning expression, so a single BitSet suffices.
*/
lazy val keyPositions: mutable.BitSet = {
val distKeyToPos = mutable.Map.empty[Expression, mutable.BitSet]
distribution.clustering.zipWithIndex.foreach { case (distKey, distKeyPos) =>
distKeyToPos.getOrElseUpdate(distKey.canonicalized, mutable.BitSet.empty).add(distKeyPos)
}
distKeyToPos.getOrElse(partitioning.expr.child.canonicalized, mutable.BitSet.empty)
}
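// Worked example (illustrative, not part of the patch): for a partitioning
// expression `a` and ClusteredDistribution(Seq(a, b, a)), the map built above
// is {a -> BitSet(0, 2), b -> BitSet(1)}, so keyPositions resolves to
// BitSet(0, 2). If `a` appears nowhere in the clustering keys, keyPositions
// is empty and this spec is incompatible with every other pass-through spec.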

override def isCompatibleWith(other: ShuffleSpec): Boolean = other match {
case SinglePartitionShuffleSpec =>
partitioning.numPartitions == 1
case otherPassThroughSpec @ ShufflePartitionIdPassThroughSpec(
otherPartitioning, otherDistribution) =>
// Since ShufflePartitionIdPassThrough allows only a single partitioning
// expression, we check compatibility as follows:
// 1. Both sides have the same number of clustering expressions.
// 2. Both sides have the same number of partitions.
// 3. The partitioning expressions on both sides have overlapping positions
//    in their corresponding distribution clustering keys.
distribution.clustering.length == otherDistribution.clustering.length &&
partitioning.numPartitions == otherPartitioning.numPartitions && {
val otherKeyPositions = otherPassThroughSpec.keyPositions
keyPositions.intersect(otherKeyPositions).nonEmpty
}
case ShuffleSpecCollection(specs) =>
specs.exists(isCompatibleWith)
case _ =>
false
}
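// Worked example (illustrative, not part of the patch): the following two
// specs are compatible, since both clusterings have length 2, both sides use
// 10 partitions, and both partitioning expressions sit at position 0 of their
// respective clustering keys:
//   ShufflePartitionIdPassThrough(DirectShufflePartitionID(a), 10)
//     .createShuffleSpec(ClusteredDistribution(Seq(a, b)))
//   ShufflePartitionIdPassThrough(DirectShufflePartitionID(c), 10)
//     .createShuffleSpec(ClusteredDistribution(Seq(c, d)))
// Changing either side's numPartitions, or moving a partitioning expression
// to a different position, makes isCompatibleWith return false.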

// We don't support creating partitioning for ShufflePartitionIdPassThrough.
override def canCreatePartitioning: Boolean = false

override def numPartitions: Int = partitioning.numPartitions
}

case class ShuffleSpecCollection(specs: Seq[ShuffleSpec]) extends ShuffleSpec {
override def isCompatibleWith(other: ShuffleSpec): Boolean = {
specs.exists(_.isCompatibleWith(other))
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst

import org.apache.spark.{SparkFunSuite, SparkUnsupportedOperationException}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.DirectShufflePartitionID
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.internal.SQLConf
@@ -479,4 +480,65 @@ class ShuffleSpecSuite extends SparkFunSuite with SQLHelper {
"methodName" -> "createPartitioning$",
"className" -> "org.apache.spark.sql.catalyst.plans.physical.ShuffleSpec"))
}

test("compatibility: ShufflePartitionIdPassThroughSpec on both sides") {
val ab = ClusteredDistribution(Seq($"a", $"b"))
val cd = ClusteredDistribution(Seq($"c", $"d"))
val passThrough_a_10 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"a"), 10)

// Specs with matching key positions and partition counts should be compatible
checkCompatible(
passThrough_a_10.createShuffleSpec(ab),
ShufflePartitionIdPassThrough(DirectShufflePartitionID($"c"), 10).createShuffleSpec(cd),
expected = true
)

// Different number of partitions should be incompatible
checkCompatible(
passThrough_a_10.createShuffleSpec(ab),
ShufflePartitionIdPassThrough(DirectShufflePartitionID($"c"), 5).createShuffleSpec(cd),
expected = false
)

// Mismatched key positions should be incompatible
checkCompatible(
ShufflePartitionIdPassThrough(DirectShufflePartitionID($"b"), 10).createShuffleSpec(ab),
ShufflePartitionIdPassThrough(DirectShufflePartitionID($"c"), 10).createShuffleSpec(cd),
expected = false
)

// Partitioning expression absent from the clustering keys should be incompatible
checkCompatible(
passThrough_a_10.createShuffleSpec(ClusteredDistribution(Seq($"e", $"b"))),
ShufflePartitionIdPassThrough(DirectShufflePartitionID($"c"), 10).createShuffleSpec(ab),
expected = false
)
}

test("compatibility: ShufflePartitionIdPassThroughSpec vs other specs") {
val ab = ClusteredDistribution(Seq($"a", $"b"))
val cd = ClusteredDistribution(Seq($"c", $"d"))
val passThrough_a_10 = ShufflePartitionIdPassThrough(DirectShufflePartitionID($"a"), 10)

// Compatibility with SinglePartitionShuffleSpec when numPartitions is 1
checkCompatible(
ShufflePartitionIdPassThrough(DirectShufflePartitionID($"a"), 1).createShuffleSpec(ab),
SinglePartitionShuffleSpec,
expected = true
)

// Incompatible with SinglePartitionShuffleSpec when numPartitions > 1
checkCompatible(
passThrough_a_10.createShuffleSpec(ab),
SinglePartitionShuffleSpec,
expected = false
)

// Incompatible with HashShuffleSpec
checkCompatible(
passThrough_a_10.createShuffleSpec(ab),
HashShuffleSpec(HashPartitioning(Seq($"c"), 10), cd),
expected = false
)
}
}
@@ -165,22 +165,29 @@ case class EnsureRequirements(
// Check if the following conditions are satisfied:
// 1. There are exactly two children (e.g., join). Note that Spark doesn't support
// multi-way join at the moment, so this check should be sufficient.
// 2. All children share the same type of partitioning (e.g., `KeyGroupedPartitioning`
//    or `ShufflePartitionIdPassThrough`) and are compatible with each other.
// If both are true, skip shuffle.
val areChildrenCompatible = parent.isDefined &&
children.length == 2 && childrenIndexes.length == 2 && {
val left = children.head
val right = children(1)

// key group compatibility check
val newChildren = checkKeyGroupCompatible(
parent.get, left, right, requiredChildDistributions)
if (newChildren.isDefined) {
children = newChildren.get
true
} else {
// If key group check fails, check ShufflePartitionIdPassThrough compatibility
checkShufflePartitionIdPassThroughCompatible(
left, right, requiredChildDistributions)
}
}

children = children.zip(requiredChildDistributions).zipWithIndex.map {
case ((child, _), idx) if areChildrenCompatible ||
    !childrenIndexes.contains(idx) =>
child
case ((child, dist), idx) =>
if (bestSpecOpt.isDefined && bestSpecOpt.get.isCompatibleWith(specs(idx))) {
@@ -600,6 +607,23 @@ case class EnsureRequirements(
if (isCompatible) Some(Seq(newLeft, newRight)) else None
}

private def checkShufflePartitionIdPassThroughCompatible(
left: SparkPlan,
right: SparkPlan,
requiredChildDistribution: Seq[Distribution]): Boolean = {
(left.outputPartitioning, right.outputPartitioning) match {
case (p1: ShufflePartitionIdPassThrough, p2: ShufflePartitionIdPassThrough) =>
assert(requiredChildDistribution.length == 2)
val leftSpec = p1.createShuffleSpec(
requiredChildDistribution.head.asInstanceOf[ClusteredDistribution])
val rightSpec = p2.createShuffleSpec(
requiredChildDistribution(1).asInstanceOf[ClusteredDistribution])
leftSpec.isCompatibleWith(rightSpec)
case _ =>
false
}
}
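// Sketch of the intended effect (illustrative, not part of the patch): for a
// join whose children report
//   left:  ShufflePartitionIdPassThrough(DirectShufflePartitionID(a), 10)
//   right: ShufflePartitionIdPassThrough(DirectShufflePartitionID(c), 10)
// under ClusteredDistribution(Seq(a, b)) and ClusteredDistribution(Seq(c, d)),
// the specs are compatible, areChildrenCompatible evaluates to true, and no
// ShuffleExchangeExec is inserted on either side.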

// Similar to `OptimizeSkewedJoin.canSplitRightSide`
private def canReplicateLeftSide(joinType: JoinType): Boolean = {
joinType == Inner || joinType == Cross || joinType == RightOuter