From 779dc55bdc19a85da73c76f270c799a900c7084a Mon Sep 17 00:00:00 2001 From: Rishi Chandra Date: Wed, 22 Apr 2026 09:00:50 -0700 Subject: [PATCH 01/12] WIP hash join reuse Signed-off-by: Rishi Chandra --- .../com/nvidia/spark/rapids/RapidsConf.scala | 11 + .../GpuBroadcastHashJoinExecBase.scala | 5 +- .../sql/rapids/execution/GpuHashJoin.scala | 510 +++++++++++++++--- 3 files changed, 460 insertions(+), 66 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala index 5cfc9cb09bf..cc7b8240a40 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -782,6 +782,15 @@ val GPU_COREDUMP_PIPE_PATTERN = conf("spark.rapids.gpu.coreDump.pipePattern") .checkValues(JoinBuildSideSelection.values.map(_.toString)) .createWithDefault(JoinBuildSideSelection.AUTO.toString) + val BROADCAST_HASH_TABLE_REUSE = + conf("spark.rapids.sql.join.broadcastHashTable.reuse") + .doc("Enable reuse of broadcast-side hash table state across stream batches for " + + "broadcast hash joins. This only applies when the broadcast side remains the " + + "physical build side selected by the join implementation.") + .internal() + .booleanConf + .createWithDefault(false) + val LOG_JOIN_CARDINALITY = conf("spark.rapids.sql.join.logCardinality") .doc("Enable logging of join cardinality statistics to help diagnose performance issues. " + "When enabled, logs task context, key data types, join condition, row counts, and " + @@ -3280,6 +3289,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging { lazy val joinGathererSizeEstimateThreshold: Double = get(JOIN_GATHERER_SIZE_ESTIMATE_THRESHOLD) + lazy val broadcastHashTableReuse: Boolean = get(BROADCAST_HASH_TABLE_REUSE) + /** * Get join options based on the current configuration. 
* @param targetSize the target batch size in bytes to use for the join diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExecBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExecBase.scala index 9fad3be3153..52aa61154d5 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExecBase.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExecBase.scala @@ -155,6 +155,7 @@ abstract class GpuBroadcastHashJoinExecBase( val targetSize = RapidsConf.GPU_BATCH_SIZE_BYTES.get(conf) val joinOptions = RapidsConf.getJoinOptions(conf, targetSize) + val enableBuildSideReuse = RapidsConf.BROADCAST_HASH_TABLE_REUSE.get(conf) val broadcastRelation = broadcastExchange.executeColumnarBroadcast[Any]() @@ -187,12 +188,12 @@ abstract class GpuBroadcastHashJoinExecBase( boundStreamKeys) } doJoin(builtBatch, nullFilteredStreamIter, joinOptions, numOutputRows, - numOutputBatches, opTime, joinTime) + numOutputBatches, opTime, joinTime, enableBuildSideReuse) } } else { // builtBatch will be closed in doJoin doJoin(builtBatch, streamIter, joinOptions, numOutputRows, numOutputBatches, opTime, - joinTime) + joinTime, enableBuildSideReuse) } } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala index 2294d56b266..ae1b87656b9 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala @@ -15,13 +15,13 @@ */ package org.apache.spark.sql.rapids.execution -import ai.rapids.cudf.{ColumnView, DType, GatherMap, NullEquality, OutOfBoundsPolicy, Scalar, Table} +import ai.rapids.cudf.{ColumnView, DType, GatherMap, HashJoin => CudfHashJoin, NullEquality, OutOfBoundsPolicy, Scalar, Table} import ai.rapids.cudf.ast.CompiledExpression import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.RmmRapidsRetryIterator.{withRestoreOnRetry, withRetryNoSplit} -import com.nvidia.spark.rapids.jni.{GpuOOM, JoinPrimitives} +import com.nvidia.spark.rapids.jni.{DistinctHashJoin, GpuOOM, JoinPrimitives} import com.nvidia.spark.rapids.shims.ShimBinaryExecNode import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, NamedExpression} @@ -299,6 +299,82 @@ object JoinImpl { GatherMapsResult.makeFromLeft(leftRet) } + def innerHashJoinBuildLeft( + rightKeys: Table, + leftHashJoin: CudfHashJoin, + outputRowCount: Option[Long] = None): GatherMapsResult = { + val arrayRet = outputRowCount.map(rightKeys.innerJoinGatherMaps(leftHashJoin, _)) + .getOrElse(rightKeys.innerJoinGatherMaps(leftHashJoin)) + GatherMapsResult(arrayRet(1), arrayRet(0)) + } + + def innerHashJoinBuildRight( + leftKeys: Table, + rightHashJoin: CudfHashJoin, + outputRowCount: Option[Long] = None): GatherMapsResult = { + val arrayRet = outputRowCount.map(leftKeys.innerJoinGatherMaps(rightHashJoin, _)) + .getOrElse(leftKeys.innerJoinGatherMaps(rightHashJoin)) + GatherMapsResult(arrayRet(0), arrayRet(1)) + } + + def leftOuterHashJoinBuildRight( + leftKeys: Table, + rightHashJoin: CudfHashJoin, + outputRowCount: Option[Long] = None): GatherMapsResult = { + val arrayRet = 
outputRowCount.map(leftKeys.leftJoinGatherMaps(rightHashJoin, _)) + .getOrElse(leftKeys.leftJoinGatherMaps(rightHashJoin)) + GatherMapsResult(arrayRet(0), arrayRet(1)) + } + + def rightOuterHashJoinBuildLeft( + rightKeys: Table, + leftHashJoin: CudfHashJoin, + outputRowCount: Option[Long] = None): GatherMapsResult = { + val arrayRet = outputRowCount.map(rightKeys.leftJoinGatherMaps(leftHashJoin, _)) + .getOrElse(rightKeys.leftJoinGatherMaps(leftHashJoin)) + GatherMapsResult(arrayRet(1), arrayRet(0)) + } + + def innerDistinctHashJoinBuildLeft( + rightKeys: Table, + leftHashJoin: DistinctHashJoin): GatherMapsResult = { + val arrayRet = leftHashJoin.innerJoinGatherMaps(rightKeys) + GatherMapsResult(arrayRet(1), arrayRet(0)) + } + + def innerDistinctHashJoinBuildRight( + leftKeys: Table, + rightHashJoin: DistinctHashJoin): GatherMapsResult = { + val arrayRet = rightHashJoin.innerJoinGatherMaps(leftKeys) + GatherMapsResult(arrayRet(0), arrayRet(1)) + } + + def leftOuterDistinctHashJoinBuildRight( + leftKeys: Table, + rightHashJoin: DistinctHashJoin): GatherMapsResult = { + val rightRet = rightHashJoin.leftJoinGatherMap(leftKeys) + GatherMapsResult.makeFromRight(rightRet) + } + + def rightOuterDistinctHashJoinBuildLeft( + rightKeys: Table, + leftHashJoin: DistinctHashJoin): GatherMapsResult = { + val leftRet = leftHashJoin.leftJoinGatherMap(rightKeys) + GatherMapsResult.makeFromLeft(leftRet) + } + + def innerHashJoinBuildLeftRowCount(rightKeys: Table, leftHashJoin: CudfHashJoin): Long = + rightKeys.innerJoinRowCount(leftHashJoin) + + def innerHashJoinBuildRightRowCount(leftKeys: Table, rightHashJoin: CudfHashJoin): Long = + leftKeys.innerJoinRowCount(rightHashJoin) + + def leftOuterHashJoinBuildRightRowCount(leftKeys: Table, rightHashJoin: CudfHashJoin): Long = + leftKeys.leftJoinRowCount(rightHashJoin) + + def rightOuterHashJoinBuildLeftRowCount(rightKeys: Table, leftHashJoin: CudfHashJoin): Long = + rightKeys.leftJoinRowCount(leftHashJoin) + /** * Do an inner hash join with the build table as the left table. 
* @param leftKeys the left equality join keys @@ -1046,6 +1122,8 @@ abstract class BaseHashJoinIterator( joinOptions: JoinOptions, joinType: JoinType, buildSide: GpuBuildSide, + enableBuildSideReuse: Boolean, + compareNullsEqual: Boolean, conditionForLogging: Option[Expression], opTime: GpuMetric, joinTime: GpuMetric) @@ -1061,7 +1139,7 @@ abstract class BaseHashJoinIterator( // We can cache this because the build side is not changing protected lazy val buildStats: JoinBuildSideStats = buildStatsOpt.getOrElse { joinType match { - case _: InnerLike | LeftOuter | RightOuter | FullOuter => + case _: InnerLike | LeftOuter | RightOuter | FullOuter | LeftSemi | LeftAnti => built.checkpoint() withRetryNoSplit { withRestoreOnRetry(built) { @@ -1074,6 +1152,71 @@ abstract class BaseHashJoinIterator( } } + private[this] var buildSideReuseDisabled = !enableBuildSideReuse + private[this] var cachedHashJoin: Option[CudfHashJoin] = None + private[this] var cachedDistinctHashJoin: Option[DistinctHashJoin] = None + + private def createCachedBuildSideReuseHandle[T <: AutoCloseable](factory: Table => T): T = { + built.checkpoint() + withRetryNoSplit { + withRestoreOnRetry(built) { + withResource(GpuProjectExec.project(built.getBatch, boundBuiltKeys)) { builtKeys => + try { + withResource(GpuColumnVector.from(builtKeys)) { builtKeysTable => + factory(builtKeysTable) + } + } finally { + built.allowSpilling() + } + } + } + } + } + + protected def cachedHashJoinFor(expectedBuildSide: GpuBuildSide): Option[CudfHashJoin] = { + if (buildSideReuseDisabled || buildStats.isDistinct || expectedBuildSide != buildSide) { + None + } else { + if (cachedHashJoin.isEmpty) { + cachedHashJoin = Some( + createCachedBuildSideReuseHandle(buildKeys => new CudfHashJoin(buildKeys, compareNullsEqual))) + } + cachedHashJoin + } + } + + protected def cachedDistinctHashJoinFor(expectedBuildSide: GpuBuildSide): Option[DistinctHashJoin] = { + if (buildSideReuseDisabled || !buildStats.isDistinct || expectedBuildSide != buildSide) { + None + } else { + if (cachedDistinctHashJoin.isEmpty) { + cachedDistinctHashJoin = Some( + createCachedBuildSideReuseHandle(buildKeys => + new DistinctHashJoin(buildKeys, compareNullsEqual))) + } + cachedDistinctHashJoin + } + } + + protected def canReusePhysicalBuildSide( + expectedBuildSide: GpuBuildSide, + leftRowCount: Long, + rightRowCount: Long): Boolean = { + !buildSideReuseDisabled && + expectedBuildSide == buildSide && + JoinBuildSideSelection.selectPhysicalBuildSide( + joinOptions.buildSideSelection, expectedBuildSide, leftRowCount, rightRowCount) == + expectedBuildSide + } + + protected def disableBuildSideReuse(): Unit = { + buildSideReuseDisabled = true + cachedHashJoin.foreach(_.close()) + cachedHashJoin = None + cachedDistinctHashJoin.foreach(_.close()) + cachedDistinctHashJoin = None + } + /** * Check if sort join is supported for the given key expressions. * Sort join does not support ARRAY or STRUCT types in join keys. 
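
The cached handles above are meant to be consumed with an Option-based fallback: a caller
asks for the cached cudf `HashJoin` for the build side it expects, and falls back to the
existing single-shot `JoinImpl` path whenever reuse is disabled, the build side is distinct,
or the expected side does not match. A condensed sketch of that dispatch — all names come
from this patch, and the full versions appear in the hunks below:

    // Try the reusable hash table first; rebuild from the raw key tables otherwise.
    val maps: GatherMapsResult = cachedHashJoinFor(GpuBuildRight)
      .map(hj => JoinImpl.leftOuterHashJoinBuildRight(leftKeys, hj, numJoinRows))
      .getOrElse(JoinImpl.leftOuterHashJoinBuildRight(leftKeys, rightKeys, compareNullsEqual))
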
@@ -1281,7 +1424,7 @@ abstract class BaseHashJoinIterator( withResource(GpuProjectExec.project(built.getBatch, boundBuiltKeys)) { builtKeys => // ensure that the build data can be spilled built.allowSpilling() - joinGatherer(builtKeys, built, streamBatch) + joinGatherer(builtKeys, built, streamBatch, numJoinRows) } } } @@ -1294,6 +1437,7 @@ abstract class BaseHashJoinIterator( || joinType == LeftOuter || joinType == RightOuter || joinType == FullOuter => + disableBuildSideReuse() // Because this is just an estimate, it is possible for us to get this wrong, so // make sure we at least split the batch in half. val numBatches = Math.max(2, estimatedNumBatches(spillOnlyCb)) @@ -1317,16 +1461,18 @@ abstract class BaseHashJoinIterator( leftKeys: Table, leftData: LazySpillableColumnarBatch, rightKeys: Table, - rightData: LazySpillableColumnarBatch): Option[JoinGatherer] + rightData: LazySpillableColumnarBatch, + numJoinRows: Option[Long]): Option[JoinGatherer] private def joinGathererLeftRight( leftKeys: ColumnarBatch, leftData: LazySpillableColumnarBatch, rightKeys: ColumnarBatch, - rightData: LazySpillableColumnarBatch): Option[JoinGatherer] = { + rightData: LazySpillableColumnarBatch, + numJoinRows: Option[Long]): Option[JoinGatherer] = { withResource(GpuColumnVector.from(leftKeys)) { leftKeysTab => withResource(GpuColumnVector.from(rightKeys)) { rightKeysTab => - joinGathererLeftRight(leftKeysTab, leftData, rightKeysTab, rightData) + joinGathererLeftRight(leftKeysTab, leftData, rightKeysTab, rightData, numJoinRows) } } } @@ -1335,23 +1481,33 @@ abstract class BaseHashJoinIterator( buildKeys: ColumnarBatch, buildData: LazySpillableColumnarBatch, streamKeys: ColumnarBatch, - streamData: LazySpillableColumnarBatch): Option[JoinGatherer] = { + streamData: LazySpillableColumnarBatch, + numJoinRows: Option[Long]): Option[JoinGatherer] = { buildSide match { case GpuBuildLeft => - joinGathererLeftRight(buildKeys, buildData, streamKeys, streamData) + joinGathererLeftRight(buildKeys, buildData, streamKeys, streamData, numJoinRows) case GpuBuildRight => - joinGathererLeftRight(streamKeys, streamData, buildKeys, buildData) + joinGathererLeftRight(streamKeys, streamData, buildKeys, buildData, numJoinRows) } } private def joinGatherer( buildKeys: ColumnarBatch, buildData: LazySpillableColumnarBatch, - streamCb: LazySpillableColumnarBatch): Option[JoinGatherer] = { + streamCb: LazySpillableColumnarBatch, + numJoinRows: Option[Long]): Option[JoinGatherer] = { withResource(GpuProjectExec.project(streamCb.getBatch, boundStreamKeys)) { streamKeys => // ensure we make the stream side spillable again streamCb.allowSpilling() - joinGatherer(buildKeys, LazySpillableColumnarBatch.spillOnly(buildData), streamKeys, streamCb) + joinGatherer(buildKeys, LazySpillableColumnarBatch.spillOnly(buildData), streamKeys, streamCb, + numJoinRows) + } + } + + override def close(): Unit = { + if (!closed) { + disableBuildSideReuse() + super.close() } } @@ -1384,7 +1540,8 @@ class HashJoinIterator( val compareNullsEqual: Boolean, // This is a workaround to how cudf support joins for structs conditionForLogging: Option[Expression], opTime: GpuMetric, - private val joinTime: GpuMetric) + private val joinTime: GpuMetric, + enableBuildSideReuse: Boolean = false) extends BaseHashJoinIterator( built, boundBuiltKeys, @@ -1395,14 +1552,41 @@ class HashJoinIterator( joinOptions, joinType, buildSide, + enableBuildSideReuse, + compareNullsEqual, conditionForLogging, opTime = opTime, joinTime = joinTime) { + + override def 
computeNumJoinRows(cb: LazySpillableColumnarBatch): Long = { + val fallback = super.computeNumJoinRows(cb) + if (buildStats.isDistinct) { + fallback + } else { + try { + withResource(GpuProjectExec.project(cb.getBatch, boundStreamKeys)) { streamKeys => + try { + withResource(GpuColumnVector.from(streamKeys)) { streamKeysTable => + exactNumJoinRows(streamKeysTable).getOrElse(fallback) + } + } finally { + cb.allowSpilling() + } + } + } catch { + case _: OutOfMemoryError | _: GpuOOM => + disableBuildSideReuse() + fallback + } + } + } + override protected def joinGathererLeftRight( leftKeys: Table, leftData: LazySpillableColumnarBatch, rightKeys: Table, - rightData: LazySpillableColumnarBatch): Option[JoinGatherer] = { + rightData: LazySpillableColumnarBatch, + numJoinRows: Option[Long]): Option[JoinGatherer] = { NvtxIdWithMetrics(NvtxRegistry.HASH_JOIN_GATHER_MAP, joinTime) { // hack to work around unique_join not handling empty tables if (joinType.isInstanceOf[InnerLike] && @@ -1415,30 +1599,12 @@ class HashJoinIterator( val maps = if (buildStats.isDistinct) { // Distinct join optimizations (highest priority, overrides strategy) - logJoinCardinality(leftKeys, rightKeys, "distinct") - val result = joinType match { - case LeftOuter => - val rightRet = leftKeys.leftDistinctJoinGatherMap(rightKeys, compareNullsEqual) - GatherMapsResult.makeFromRight(rightRet) - case RightOuter => - val leftRet = rightKeys.leftDistinctJoinGatherMap(leftKeys, compareNullsEqual) - GatherMapsResult.makeFromLeft(leftRet) - case _: InnerLike => - val arrayRet = if (buildSide == GpuBuildRight) { - leftKeys.innerDistinctJoinGatherMaps(rightKeys, compareNullsEqual) - } else { - rightKeys.innerDistinctJoinGatherMaps(leftKeys, compareNullsEqual).reverse - } - GatherMapsResult(arrayRet(0), arrayRet(1)) - case _ => - // Fall through to strategy-based dispatching for non-outer joins - computeNonDistinctJoin(leftKeys, rightKeys, leftData, rightData) - } + val result = computeDistinctJoin(leftKeys, rightKeys) logJoinCompletion() result } else { // Non-distinct joins: use strategy-based dispatching - computeNonDistinctJoin(leftKeys, rightKeys, leftData, rightData) + computeNonDistinctJoin(leftKeys, rightKeys, numJoinRows) } makeGatherer(maps, leftData, rightData, joinType) @@ -1446,11 +1612,127 @@ class HashJoinIterator( } } + private def exactNumJoinRows(streamKeys: Table): Option[Long] = { + val (leftRowCount, rightRowCount) = buildSide match { + case GpuBuildRight => (streamKeys.getRowCount, built.numRows.toLong) + case GpuBuildLeft => (built.numRows.toLong, streamKeys.getRowCount) + } + + joinType match { + case _: InnerLike if canReusePhysicalBuildSide(buildSide, leftRowCount, rightRowCount) => + buildSide match { + case GpuBuildRight => + cachedHashJoinFor(GpuBuildRight) + .map(hashJoin => JoinImpl.innerHashJoinBuildRightRowCount(streamKeys, hashJoin)) + case GpuBuildLeft => + cachedHashJoinFor(GpuBuildLeft) + .map(hashJoin => JoinImpl.innerHashJoinBuildLeftRowCount(streamKeys, hashJoin)) + } + case LeftOuter => + cachedHashJoinFor(GpuBuildRight) + .map(hashJoin => JoinImpl.leftOuterHashJoinBuildRightRowCount(streamKeys, hashJoin)) + case RightOuter => + cachedHashJoinFor(GpuBuildLeft) + .map(hashJoin => JoinImpl.rightOuterHashJoinBuildLeftRowCount(streamKeys, hashJoin)) + case _ => + None + } + } + + private def reusedGenericInnerJoin( + leftKeys: Table, + rightKeys: Table, + outputRowCount: Option[Long]): Option[GatherMapsResult] = { + if (!canReusePhysicalBuildSide(buildSide, leftKeys.getRowCount, 
rightKeys.getRowCount)) { + None + } else { + buildSide match { + case GpuBuildRight => + cachedHashJoinFor(GpuBuildRight) + .map(hashJoin => JoinImpl.innerHashJoinBuildRight(leftKeys, hashJoin, outputRowCount)) + case GpuBuildLeft => + cachedHashJoinFor(GpuBuildLeft) + .map(hashJoin => JoinImpl.innerHashJoinBuildLeft(rightKeys, hashJoin, outputRowCount)) + } + } + } + + private def reusedGenericLeftSemi( + leftKeys: Table): Option[GatherMapsResult] = { + cachedHashJoinFor(GpuBuildRight).map { hashJoin => + withResource(JoinImpl.innerHashJoinBuildRight(leftKeys, hashJoin)) { innerMaps => + JoinImpl.makeLeftSemi(innerMaps, leftKeys.getRowCount.toInt) + } + } + } + + private def reusedGenericLeftAnti( + leftKeys: Table): Option[GatherMapsResult] = { + cachedHashJoinFor(GpuBuildRight).map { hashJoin => + withResource(JoinImpl.innerHashJoinBuildRight(leftKeys, hashJoin)) { innerMaps => + JoinImpl.makeLeftAnti(innerMaps, leftKeys.getRowCount.toInt) + } + } + } + + private def computeDistinctJoin( + leftKeys: Table, + rightKeys: Table): GatherMapsResult = { + val reused = cachedDistinctHashJoinFor(buildSide).map { distinctHashJoin => + logJoinCardinality(leftKeys, rightKeys, "distinct (reused)") + joinType match { + case LeftOuter => + JoinImpl.leftOuterDistinctHashJoinBuildRight(leftKeys, distinctHashJoin) + case RightOuter => + JoinImpl.rightOuterDistinctHashJoinBuildLeft(rightKeys, distinctHashJoin) + case LeftSemi => + withResource(JoinImpl.innerDistinctHashJoinBuildRight(leftKeys, distinctHashJoin)) { innerMaps => + JoinImpl.makeLeftSemi(innerMaps, leftKeys.getRowCount.toInt) + } + case LeftAnti => + withResource(JoinImpl.innerDistinctHashJoinBuildRight(leftKeys, distinctHashJoin)) { innerMaps => + JoinImpl.makeLeftAnti(innerMaps, leftKeys.getRowCount.toInt) + } + case _: InnerLike => + if (buildSide == GpuBuildRight) { + JoinImpl.innerDistinctHashJoinBuildRight(leftKeys, distinctHashJoin) + } else { + JoinImpl.innerDistinctHashJoinBuildLeft(rightKeys, distinctHashJoin) + } + case _ => + computeDistinctJoinWithoutReuse(leftKeys, rightKeys) + } + } + reused.getOrElse(computeDistinctJoinWithoutReuse(leftKeys, rightKeys)) + } + + private def computeDistinctJoinWithoutReuse( + leftKeys: Table, + rightKeys: Table): GatherMapsResult = { + logJoinCardinality(leftKeys, rightKeys, "distinct") + joinType match { + case LeftOuter => + val rightRet = leftKeys.leftDistinctJoinGatherMap(rightKeys, compareNullsEqual) + GatherMapsResult.makeFromRight(rightRet) + case RightOuter => + val leftRet = rightKeys.leftDistinctJoinGatherMap(leftKeys, compareNullsEqual) + GatherMapsResult.makeFromLeft(leftRet) + case _: InnerLike => + val arrayRet = if (buildSide == GpuBuildRight) { + leftKeys.innerDistinctJoinGatherMaps(rightKeys, compareNullsEqual) + } else { + rightKeys.innerDistinctJoinGatherMaps(leftKeys, compareNullsEqual).reverse + } + GatherMapsResult(arrayRet(0), arrayRet(1)) + case _ => + computeNonDistinctJoin(leftKeys, rightKeys, None) + } + } + private def computeNonDistinctJoin( leftKeys: Table, rightKeys: Table, - leftData: LazySpillableColumnarBatch, - rightData: LazySpillableColumnarBatch): GatherMapsResult = { + numJoinRows: Option[Long]): GatherMapsResult = { // Apply heuristics to select the effective strategy val effectiveStrategy = JoinStrategy.selectStrategy( joinOptions.strategy, @@ -1463,7 +1745,7 @@ class HashJoinIterator( effectiveStrategy match { case JoinStrategy.INNER_HASH_WITH_POST => // Use composable JNI APIs: inner join -> convert to target join type - 
computeNonCondInnerHashWithPost(leftKeys, rightKeys) + computeNonCondInnerHashWithPost(leftKeys, rightKeys, numJoinRows) case JoinStrategy.INNER_SORT_WITH_POST => // Check if sort join is supported (no ARRAY/STRUCT types) val leftKeysSupported = isSortJoinSupported(boundBuiltKeys) @@ -1477,17 +1759,18 @@ class HashJoinIterator( s"ARRAY or STRUCT types which are not supported for sort joins. " + s"Falling back to INNER_HASH_WITH_POST strategy.") } - computeNonCondInnerHashWithPost(leftKeys, rightKeys, isFallback = true) + computeNonCondInnerHashWithPost(leftKeys, rightKeys, numJoinRows, isFallback = true) } case _ => // Use existing hash join methods (for HASH_ONLY and when AUTO doesn't trigger heuristics) - computeWithHashJoin(leftKeys, rightKeys) + computeWithHashJoin(leftKeys, rightKeys, numJoinRows) } } private def computeNonCondInnerHashWithPost( leftKeys: Table, rightKeys: Table, + numJoinRows: Option[Long], isFallback: Boolean = false): GatherMapsResult = { val implName = if (isFallback) { "INNER_HASH_WITH_POST (fallback from INNER_SORT_WITH_POST)" @@ -1496,8 +1779,10 @@ class HashJoinIterator( } logJoinCardinality(leftKeys, rightKeys, implName) - val innerMaps = JoinImpl.innerHashJoin(leftKeys, rightKeys, compareNullsEqual, - joinOptions.buildSideSelection, buildSide) + val innerMaps = reusedGenericInnerJoin(leftKeys, rightKeys, numJoinRows).getOrElse { + JoinImpl.innerHashJoin(leftKeys, rightKeys, compareNullsEqual, + joinOptions.buildSideSelection, buildSide) + } val leftRowCount = leftKeys.getRowCount val rightRowCount = rightKeys.getRowCount @@ -1525,21 +1810,30 @@ class HashJoinIterator( private def computeWithHashJoin( leftKeys: Table, - rightKeys: Table): GatherMapsResult = { + rightKeys: Table, + numJoinRows: Option[Long]): GatherMapsResult = { logJoinCardinality(leftKeys, rightKeys, "hash join") val result = joinType match { case LeftOuter => - JoinImpl.leftOuterHashJoinBuildRight(leftKeys, rightKeys, compareNullsEqual) + cachedHashJoinFor(GpuBuildRight) + .map(hashJoin => JoinImpl.leftOuterHashJoinBuildRight(leftKeys, hashJoin, numJoinRows)) + .getOrElse(JoinImpl.leftOuterHashJoinBuildRight(leftKeys, rightKeys, compareNullsEqual)) case RightOuter => - JoinImpl.rightOuterHashJoinBuildLeft(leftKeys, rightKeys, compareNullsEqual) + cachedHashJoinFor(GpuBuildLeft) + .map(hashJoin => JoinImpl.rightOuterHashJoinBuildLeft(rightKeys, hashJoin, numJoinRows)) + .getOrElse(JoinImpl.rightOuterHashJoinBuildLeft(leftKeys, rightKeys, compareNullsEqual)) case _: InnerLike => - JoinImpl.innerHashJoin(leftKeys, rightKeys, compareNullsEqual, - joinOptions.buildSideSelection, buildSide) + reusedGenericInnerJoin(leftKeys, rightKeys, numJoinRows).getOrElse { + JoinImpl.innerHashJoin(leftKeys, rightKeys, compareNullsEqual, + joinOptions.buildSideSelection, buildSide) + } case LeftSemi => - JoinImpl.leftSemiHashJoinBuildRight(leftKeys, rightKeys, compareNullsEqual) + reusedGenericLeftSemi(leftKeys) + .getOrElse(JoinImpl.leftSemiHashJoinBuildRight(leftKeys, rightKeys, compareNullsEqual)) case LeftAnti => - JoinImpl.leftAntiHashJoinBuildRight(leftKeys, rightKeys, compareNullsEqual) + reusedGenericLeftAnti(leftKeys) + .getOrElse(JoinImpl.leftAntiHashJoinBuildRight(leftKeys, rightKeys, compareNullsEqual)) case _ => throw new NotImplementedError(s"Join Type ${joinType.getClass} is not currently" + s" supported") @@ -1567,7 +1861,8 @@ class ConditionalHashJoinIterator( compareNullsEqual: Boolean, // This is a workaround to how cudf support joins for structs conditionForLogging: Option[Expression], 
opTime: GpuMetric, - joinTime: GpuMetric) + joinTime: GpuMetric, + enableBuildSideReuse: Boolean = false) extends BaseHashJoinIterator( built, boundBuiltKeys, @@ -1578,6 +1873,8 @@ class ConditionalHashJoinIterator( joinOptions, joinType, buildSide, + enableBuildSideReuse, + compareNullsEqual, conditionForLogging, opTime = opTime, joinTime = joinTime) { @@ -1592,7 +1889,8 @@ class ConditionalHashJoinIterator( leftKeys: Table, leftData: LazySpillableColumnarBatch, rightKeys: Table, - rightData: LazySpillableColumnarBatch): Option[JoinGatherer] = { + rightData: LazySpillableColumnarBatch, + numJoinRows: Option[Long]): Option[JoinGatherer] = { val nullEquality = if (compareNullsEqual) NullEquality.EQUAL else NullEquality.UNEQUAL NvtxIdWithMetrics(NvtxRegistry.HASH_JOIN_GATHER_MAP, joinTime) { withResource(GpuColumnVector.from(leftData.getBatch)) { leftTable => @@ -1783,7 +2081,8 @@ class HashJoinStreamSideIterator( compareNullsEqual: Boolean, // This is a workaround to how cudf support joins for structs conditionForLogging: Option[Expression], opTime: GpuMetric, - joinTime: GpuMetric) + joinTime: GpuMetric, + enableBuildSideReuse: Boolean = false) extends BaseHashJoinIterator( built, boundBuiltKeys, @@ -1794,6 +2093,8 @@ class HashJoinStreamSideIterator( joinOptions, joinType, buildSide, + enableBuildSideReuse, + compareNullsEqual, conditionForLogging, opTime = opTime, joinTime = joinTime) { @@ -1828,11 +2129,79 @@ class HashJoinStreamSideIterator( private[this] var builtSideTracker: Option[SpillableColumnarBatch] = buildSideTrackerInit + private def reusedUnconditionalInnerHashJoin( + leftKeys: Table, + rightKeys: Table): Option[GatherMapsResult] = { + if (!canReusePhysicalBuildSide(cudfBuildSide, leftKeys.getRowCount, rightKeys.getRowCount)) { + None + } else { + cudfBuildSide match { + case GpuBuildRight => + cachedHashJoinFor(GpuBuildRight) + .map(hashJoin => JoinImpl.innerHashJoinBuildRight(leftKeys, hashJoin)) + case GpuBuildLeft => + cachedHashJoinFor(GpuBuildLeft) + .map(hashJoin => JoinImpl.innerHashJoinBuildLeft(rightKeys, hashJoin)) + } + } + } + + private def computeDistinctUnconditionalJoin( + leftKeys: Table, + rightKeys: Table, + originalJoinType: Option[JoinType]): GatherMapsResult = { + val implName = cachedDistinctHashJoinFor(cudfBuildSide) + .map(_ => s"distinct (outer: $joinType, reused)") + .getOrElse(s"distinct (outer: $joinType)") + logJoinCardinality(leftKeys, rightKeys, implName, originalJoinType) + + val result = cachedDistinctHashJoinFor(cudfBuildSide).map { distinctHashJoin => + subJoinType match { + case LeftOuter => + JoinImpl.leftOuterDistinctHashJoinBuildRight(leftKeys, distinctHashJoin) + case RightOuter => + JoinImpl.rightOuterDistinctHashJoinBuildLeft(rightKeys, distinctHashJoin) + case Inner => + if (cudfBuildSide == GpuBuildRight) { + JoinImpl.innerDistinctHashJoinBuildRight(leftKeys, distinctHashJoin) + } else { + JoinImpl.innerDistinctHashJoinBuildLeft(rightKeys, distinctHashJoin) + } + case t => + throw new IllegalStateException(s"unsupported join type: $t") + } + }.getOrElse { + subJoinType match { + case LeftOuter => + val rightRet = leftKeys.leftDistinctJoinGatherMap(rightKeys, compareNullsEqual) + GatherMapsResult.makeFromRight(rightRet) + case RightOuter => + val leftRet = rightKeys.leftDistinctJoinGatherMap(leftKeys, compareNullsEqual) + GatherMapsResult.makeFromLeft(leftRet) + case Inner => + val arrayRet = if (cudfBuildSide == GpuBuildRight) { + leftKeys.innerDistinctJoinGatherMaps(rightKeys, compareNullsEqual) + } else { + 
rightKeys.innerDistinctJoinGatherMaps(leftKeys, compareNullsEqual).reverse + } + GatherMapsResult(arrayRet(0), arrayRet(1)) + case t => + throw new IllegalStateException(s"unsupported join type: $t") + } + } + logJoinCompletion() + result + } + private def unconditionalJoinGatherMaps( leftKeys: Table, rightKeys: Table): GatherMapsResult = { // Pass the original joinType if it was transformed to subJoinType val originalJoinType = if (joinType != subJoinType) Some(joinType) else None + if (buildStats.isDistinct) { + return computeDistinctUnconditionalJoin(leftKeys, rightKeys, originalJoinType) + } + // Apply heuristics to select the effective strategy for unconditional joins // Note: subJoinType is used for strategy selection since that's what we're actually executing val effectiveStrategy = JoinStrategy.selectStrategy( @@ -1879,8 +2248,10 @@ class HashJoinStreamSideIterator( } logJoinCardinality(leftKeys, rightKeys, implName, originalJoinType) - val innerMaps = JoinImpl.innerHashJoin(leftKeys, rightKeys, compareNullsEqual, - joinOptions.buildSideSelection, cudfBuildSide) + val innerMaps = reusedUnconditionalInnerHashJoin(leftKeys, rightKeys).getOrElse { + JoinImpl.innerHashJoin(leftKeys, rightKeys, compareNullsEqual, + joinOptions.buildSideSelection, cudfBuildSide) + } val leftRowCount = leftKeys.getRowCount val rightRowCount = rightKeys.getRowCount @@ -1918,12 +2289,18 @@ class HashJoinStreamSideIterator( val result = subJoinType match { case LeftOuter => - JoinImpl.leftOuterHashJoinBuildRight(leftKeys, rightKeys, compareNullsEqual) + cachedHashJoinFor(GpuBuildRight) + .map(hashJoin => JoinImpl.leftOuterHashJoinBuildRight(leftKeys, hashJoin)) + .getOrElse(JoinImpl.leftOuterHashJoinBuildRight(leftKeys, rightKeys, compareNullsEqual)) case RightOuter => - JoinImpl.rightOuterHashJoinBuildLeft(leftKeys, rightKeys, compareNullsEqual) + cachedHashJoinFor(GpuBuildLeft) + .map(hashJoin => JoinImpl.rightOuterHashJoinBuildLeft(rightKeys, hashJoin)) + .getOrElse(JoinImpl.rightOuterHashJoinBuildLeft(leftKeys, rightKeys, compareNullsEqual)) case Inner => - JoinImpl.innerHashJoin(leftKeys, rightKeys, compareNullsEqual, - joinOptions.buildSideSelection, cudfBuildSide) + reusedUnconditionalInnerHashJoin(leftKeys, rightKeys).getOrElse { + JoinImpl.innerHashJoin(leftKeys, rightKeys, compareNullsEqual, + joinOptions.buildSideSelection, cudfBuildSide) + } case t => throw new IllegalStateException(s"unsupported join type: $t") } @@ -2085,7 +2462,8 @@ class HashJoinStreamSideIterator( leftKeys: Table, leftData: LazySpillableColumnarBatch, rightKeys: Table, - rightData: LazySpillableColumnarBatch): Option[JoinGatherer] = { + rightData: LazySpillableColumnarBatch, + numJoinRows: Option[Long]): Option[JoinGatherer] = { NvtxIdWithMetrics(NvtxRegistry.FULL_HASH_JOIN_GATHER_MAP, joinTime) { val maps = lazyCompiledCondition.map { lazyCondition => conditionalJoinGatherMaps(leftKeys, leftData, rightKeys, rightData, lazyCondition) @@ -2242,12 +2620,14 @@ class HashOuterJoinIterator( compareNullsEqual: Boolean, // This is a workaround to how cudf support joins for structs conditionForLogging: Option[Expression], opTime: GpuMetric, - joinTime: GpuMetric) extends Iterator[ColumnarBatch] with TaskAutoCloseableResource { + joinTime: GpuMetric, + enableBuildSideReuse: Boolean = false) extends Iterator[ColumnarBatch] with TaskAutoCloseableResource { private val streamJoinIter = new HashJoinStreamSideIterator(joinType, built, boundBuiltKeys, buildStats, buildSideTrackerInit, stream, boundStreamKeys, streamAttributes, 
lazyCompiledCondition, - joinOptions, buildSide, compareNullsEqual, conditionForLogging, opTime, joinTime) + joinOptions, buildSide, compareNullsEqual, conditionForLogging, opTime, joinTime, + enableBuildSideReuse) private var finalBatch: Option[ColumnarBatch] = None @@ -2551,7 +2931,8 @@ trait GpuHashJoin extends GpuJoinExec { numOutputRows: GpuMetric, numOutputBatches: GpuMetric, opTime: GpuMetric, - joinTime: GpuMetric): Iterator[ColumnarBatch] = { + joinTime: GpuMetric, + enableBuildSideReuse: Boolean = false): Iterator[ColumnarBatch] = { val filterOutNull = GpuHashJoin.buildSideNeedsNullFilter(joinType, compareNullsEqual, buildSide, buildKeys) @@ -2600,7 +2981,7 @@ trait GpuHashJoin extends GpuJoinExec { new HashOuterJoinIterator(joinType, spillableBuiltBatch, boundBuildKeys, None, None, lazyStream, boundStreamKeys, streamedPlan.output, lazyCond, joinOptions, buildSide, - compareNullsEqual, condition, opTime, joinTime) + compareNullsEqual, condition, opTime, joinTime, enableBuildSideReuse) case _ => if (boundConditionLeftRight.isDefined) { // ConditionalHashJoinIterator will close the LazyCompiledCondition @@ -2612,11 +2993,12 @@ trait GpuHashJoin extends GpuJoinExec { new ConditionalHashJoinIterator(spillableBuiltBatch, boundBuildKeys, None, lazyStream, boundStreamKeys, streamedPlan.output, lazyCond, joinOptions, joinType, buildSide, - compareNullsEqual, condition, opTime, joinTime) + compareNullsEqual, condition, opTime, joinTime, enableBuildSideReuse) } else { new HashJoinIterator(spillableBuiltBatch, boundBuildKeys, None, lazyStream, boundStreamKeys, streamedPlan.output, joinOptions, - joinType, buildSide, compareNullsEqual, condition, opTime, joinTime) + joinType, buildSide, compareNullsEqual, condition, opTime, joinTime, + enableBuildSideReuse) } } From db7a6711a0cb57f16e2c400ba0b73a314f289458 Mon Sep 17 00:00:00 2001 From: Rishi Chandra Date: Wed, 22 Apr 2026 10:22:06 -0700 Subject: [PATCH 02/12] Add broadcast hash join reuse coverage Signed-off-by: Rishi Chandra --- .../spark/rapids/BroadcastHashJoinSuite.scala | 79 ++++++++++++++++++- 1 file changed, 78 insertions(+), 1 deletion(-) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala index 70fbf438f1d..8d52f2b57ed 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. + * Copyright (c) 2020-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -19,10 +19,31 @@ package com.nvidia.spark.rapids import com.nvidia.spark.rapids.TestUtils.findOperator import org.apache.spark.SparkConf +import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.functions.broadcast import org.apache.spark.sql.rapids.execution.{GpuBroadcastHashJoinExec, GpuHashJoin} class BroadcastHashJoinSuite extends SparkQueryCompareTestSuite { + private def broadcastReuseConf: SparkConf = new SparkConf() + .set("spark.sql.adaptive.enabled", "false") + .set("spark.sql.autoBroadcastJoinThreshold", "-1") + .set("spark.rapids.sql.join.broadcastHashTable.reuse", "true") + .set("spark.rapids.sql.batchSizeBytes", "1") + + private def streamedProbeDf(spark: SparkSession): DataFrame = + spark.range(0, 128).selectExpr( + "CAST(id % 8 AS INT) AS join_key", + "CAST(id AS INT) AS probe_value") + + private def distinctBuildDf(spark: SparkSession): DataFrame = + spark.range(0, 8).selectExpr( + "CAST(id AS INT) AS join_key", + "CAST(id * 10 AS INT) AS build_value") + + private def nonDistinctBuildDf(spark: SparkSession): DataFrame = + spark.range(0, 16).selectExpr( + "CAST(id % 4 AS INT) AS join_key", + "CAST(id AS INT) AS build_value") test("broadcast hint isn't propagated after a join") { val conf = new SparkConf() @@ -71,4 +92,60 @@ class BroadcastHashJoinSuite extends SparkQueryCompareTestSuite { } }) } + + IGNORE_ORDER_testSparkResultsAreEqual2( + "broadcast hash join reuse distinct inner build right", + streamedProbeDf, + distinctBuildDf, + conf = broadcastReuseConf) { + (probe, build) => probe.join(broadcast(build), Seq("join_key"), "inner") + } + + IGNORE_ORDER_testSparkResultsAreEqual2( + "broadcast hash join reuse distinct left outer build right", + streamedProbeDf, + distinctBuildDf, + conf = broadcastReuseConf) { + (probe, build) => probe.join(broadcast(build), Seq("join_key"), "left") + } + + IGNORE_ORDER_testSparkResultsAreEqual2( + "broadcast hash join reuse distinct right outer build left", + distinctBuildDf, + streamedProbeDf, + conf = broadcastReuseConf) { + (build, probe) => broadcast(build).join(probe, Seq("join_key"), "right") + } + + IGNORE_ORDER_testSparkResultsAreEqual2( + "broadcast hash join reuse non-distinct inner build right", + streamedProbeDf, + nonDistinctBuildDf, + conf = broadcastReuseConf) { + (probe, build) => probe.join(broadcast(build), Seq("join_key"), "inner") + } + + IGNORE_ORDER_testSparkResultsAreEqual2( + "broadcast hash join reuse non-distinct inner build left", + nonDistinctBuildDf, + streamedProbeDf, + conf = broadcastReuseConf) { + (build, probe) => broadcast(build).join(probe, Seq("join_key"), "inner") + } + + IGNORE_ORDER_testSparkResultsAreEqual2( + "broadcast hash join reuse non-distinct left semi build right", + streamedProbeDf, + nonDistinctBuildDf, + conf = broadcastReuseConf) { + (probe, build) => probe.join(broadcast(build), Seq("join_key"), "leftsemi") + } + + IGNORE_ORDER_testSparkResultsAreEqual2( + "broadcast hash join reuse non-distinct left anti build right", + streamedProbeDf, + nonDistinctBuildDf, + conf = broadcastReuseConf) { + (probe, build) => probe.join(broadcast(build), Seq("join_key"), "leftanti") + } } From 1b50e12fe5042ba2d149dfdd0a8719186627a309 Mon Sep 17 00:00:00 2001 From: Rishi Chandra Date: Wed, 22 Apr 2026 11:30:22 -0700 Subject: [PATCH 03/12] Fix hash join reuse review issues Signed-off-by: Rishi Chandra --- .../rapids/GpuShuffledHashJoinExec.scala | 3 +- .../sql/rapids/execution/GpuHashJoin.scala | 28 ++++++---- .../execution/GpuSubPartitionHashJoin.scala | 2 +- 
.../execution/GpuBroadcastHashJoinExec.scala | 6 +-- .../spark/rapids/BroadcastHashJoinSuite.scala | 53 +++++++++++++++++++ .../com/nvidia/spark/rapids/JoinsSuite.scala | 53 +++++++++++++++++++ 6 files changed, 128 insertions(+), 17 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledHashJoinExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledHashJoinExec.scala index e7bf11febf7..76e0fe3842b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledHashJoinExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledHashJoinExec.scala @@ -264,7 +264,7 @@ case class GpuShuffledHashJoinExec( } // doJoin will close singleBatch doJoin(singleBatch, maybeBufferedStreamIter, joinOptions, - numOutputRows, numOutputBatches, opTime, joinTime) + numOutputRows, numOutputBatches, opTime, joinTime, enableBuildSideReuse = false) case Right(builtBatchIter) => // For big joins, when the build data can not fit into a single batch. val sizeBuildIter = builtBatchIter.map { cb => @@ -542,4 +542,3 @@ object GpuShuffledHashJoinExec extends Logging { retIter } } - diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala index ae1b87656b9..8d067ae1aff 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala @@ -1211,10 +1211,11 @@ abstract class BaseHashJoinIterator( protected def disableBuildSideReuse(): Unit = { buildSideReuseDisabled = true - cachedHashJoin.foreach(_.close()) + val hashJoinToClose = cachedHashJoin cachedHashJoin = None - cachedDistinctHashJoin.foreach(_.close()) + val distinctHashJoinToClose = cachedDistinctHashJoin cachedDistinctHashJoin = None + Seq(hashJoinToClose, distinctHashJoinToClose).flatten.safeClose() } /** @@ -1559,24 +1560,27 @@ class HashJoinIterator( joinTime = joinTime) { override def computeNumJoinRows(cb: LazySpillableColumnarBatch): Long = { - val fallback = super.computeNumJoinRows(cb) + lazy val fallback = super.computeNumJoinRows(cb) if (buildStats.isDistinct) { fallback } else { + cb.checkpoint() try { - withResource(GpuProjectExec.project(cb.getBatch, boundStreamKeys)) { streamKeys => - try { - withResource(GpuColumnVector.from(streamKeys)) { streamKeysTable => - exactNumJoinRows(streamKeysTable).getOrElse(fallback) + withRetryNoSplit { + withRestoreOnRetry(cb) { + withResource(GpuProjectExec.project(cb.getBatch, boundStreamKeys)) { streamKeys => + withResource(GpuColumnVector.from(streamKeys)) { streamKeysTable => + exactNumJoinRows(streamKeysTable).getOrElse(fallback) + } } - } finally { - cb.allowSpilling() } } } catch { case _: OutOfMemoryError | _: GpuOOM => disableBuildSideReuse() fallback + } finally { + cb.allowSpilling() } } } @@ -2198,7 +2202,9 @@ class HashJoinStreamSideIterator( // Pass the original joinType if it was transformed to subJoinType val originalJoinType = if (joinType != subJoinType) Some(joinType) else None - if (buildStats.isDistinct) { + // The distinct outer path only produces a single gather map for LeftOuter/RightOuter, + // but HashJoinStreamSideIterator always needs both maps to update tracking state. 
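+    // Restricting the distinct fast path to Inner keeps the tracking logic intact: the
+    // generic paths below always return both gather maps, so builtSideTracker can still
+    // record which build-side rows were matched.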
+ if (buildStats.isDistinct && subJoinType == Inner) { return computeDistinctUnconditionalJoin(leftKeys, rightKeys, originalJoinType) } @@ -2932,7 +2938,7 @@ trait GpuHashJoin extends GpuJoinExec { numOutputBatches: GpuMetric, opTime: GpuMetric, joinTime: GpuMetric, - enableBuildSideReuse: Boolean = false): Iterator[ColumnarBatch] = { + enableBuildSideReuse: Boolean): Iterator[ColumnarBatch] = { val filterOutNull = GpuHashJoin.buildSideNeedsNullFilter(joinType, compareNullsEqual, buildSide, buildKeys) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubPartitionHashJoin.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubPartitionHashJoin.scala index 12faa151a4b..42d96f04ea9 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubPartitionHashJoin.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubPartitionHashJoin.scala @@ -607,7 +607,7 @@ trait GpuSubPartitionHashJoin extends Logging { self: GpuHashJoin => } // Leverage the original join iterators val joinIter = doJoin(buildCb, streamIter, joinOptions, - numOutputRows, numOutputBatches, opTime, joinTime) + numOutputRows, numOutputBatches, opTime, joinTime, enableBuildSideReuse = false) Some(joinIter) } } diff --git a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala index 395c5b19bfc..52a948c6f0b 100644 --- a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala +++ b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala @@ -164,6 +164,7 @@ case class GpuBroadcastHashJoinExec( val targetSize = RapidsConf.GPU_BATCH_SIZE_BYTES.get(conf) val joinOptions = RapidsConf.getJoinOptions(conf, targetSize) + val enableBuildSideReuse = RapidsConf.BROADCAST_HASH_TABLE_REUSE.get(conf) // Get all the broadcast data from the shuffle coalesced into a single partition val partitionSpecs = Seq(CoalescedPartitionSpec(0, shuffleExchange.numPartitions)) @@ -202,12 +203,12 @@ case class GpuBroadcastHashJoinExec( boundStreamKeys) } doJoin(builtBatch, nullFilteredStreamIter, joinOptions, numOutputRows, - numOutputBatches, opTime, joinTime) + numOutputBatches, opTime, joinTime, enableBuildSideReuse) } } else { // builtBatch will be closed in doJoin doJoin(builtBatch, streamIter, joinOptions, numOutputRows, numOutputBatches, opTime, - joinTime) + joinTime, enableBuildSideReuse) } } } @@ -220,4 +221,3 @@ case class GpuBroadcastHashJoinExec( } } } - diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala index 8d52f2b57ed..abdce14eaf8 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala @@ -45,6 +45,32 @@ class BroadcastHashJoinSuite extends SparkQueryCompareTestSuite { "CAST(id % 4 AS INT) AS join_key", "CAST(id AS INT) AS build_value") + private def nullableProbeDf(spark: SparkSession): DataFrame = + spark.range(0, 8).selectExpr( + "CAST(CASE CAST(id AS INT) " + + "WHEN 0 THEN NULL " + + "WHEN 1 THEN 0 " + + "WHEN 2 THEN 1 " + + "WHEN 3 THEN 2 " + + "WHEN 4 THEN 3 " + + "WHEN 5 THEN 4 " + + "WHEN 6 THEN 5 " + + "ELSE 8 END AS INT) AS join_key", + "CAST(id AS INT) AS 
probe_value") + + private def nullableDistinctBuildDf(spark: SparkSession): DataFrame = + spark.range(0, 8).selectExpr( + "CAST(CASE CAST(id AS INT) " + + "WHEN 0 THEN NULL " + + "WHEN 1 THEN 0 " + + "WHEN 2 THEN 1 " + + "WHEN 3 THEN 2 " + + "WHEN 4 THEN 3 " + + "WHEN 5 THEN 4 " + + "WHEN 6 THEN 6 " + + "ELSE 9 END AS INT) AS join_key", + "CAST(id * 10 AS INT) AS build_value") + test("broadcast hint isn't propagated after a join") { val conf = new SparkConf() .set("spark.sql.autoBroadcastJoinThreshold", "-1") @@ -148,4 +174,31 @@ class BroadcastHashJoinSuite extends SparkQueryCompareTestSuite { conf = broadcastReuseConf) { (probe, build) => probe.join(broadcast(build), Seq("join_key"), "leftanti") } + + IGNORE_ORDER_testSparkResultsAreEqual2( + "broadcast hash join reuse distinct inner nullable keys build right", + nullableProbeDf, + nullableDistinctBuildDf, + conf = broadcastReuseConf) { + (probe, build) => probe.join(broadcast(build), Seq("join_key"), "inner") + } + + IGNORE_ORDER_testSparkResultsAreEqual2( + "broadcast hash join reuse distinct left outer nullable keys build right", + nullableProbeDf, + nullableDistinctBuildDf, + conf = broadcastReuseConf) { + (probe, build) => probe.join(broadcast(build), Seq("join_key"), "left") + } + + IGNORE_ORDER_testSparkResultsAreEqual2( + "broadcast hash join reuse conditional inner build right", + streamedProbeDf, + distinctBuildDf, + conf = broadcastReuseConf) { + (probe, build) => + probe.join(broadcast(build), + probe("join_key") === build("join_key") && + probe("probe_value") <= build("build_value"), "inner") + } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/JoinsSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/JoinsSuite.scala index 2f69cf6ca61..d04bb4c91b8 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/JoinsSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/JoinsSuite.scala @@ -17,6 +17,7 @@ package com.nvidia.spark.rapids import org.apache.spark.SparkConf +import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.ExistenceJoin import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, Join, JoinHint} @@ -45,6 +46,39 @@ class JoinsSuite extends SparkQueryCompareTestSuite { .set("spark.sql.join.preferSortMergeJoin", "false") .set("spark.sql.shuffle.partitions", "2") // hack to try and work around bug in cudf + private def shuffledDistinctOuterConf: SparkConf = new SparkConf() + .set("spark.sql.adaptive.enabled", "false") + .set("spark.sql.autoBroadcastJoinThreshold", "-1") + .set("spark.sql.join.preferSortMergeJoin", "false") + .set("spark.sql.shuffle.partitions", "2") + .set("spark.rapids.sql.batchSizeBytes", "1") + + private def distinctNullableLeftDf(spark: SparkSession): DataFrame = + spark.range(0, 8).selectExpr( + "CAST(CASE CAST(id AS INT) " + + "WHEN 0 THEN NULL " + + "WHEN 1 THEN 0 " + + "WHEN 2 THEN 1 " + + "WHEN 3 THEN 2 " + + "WHEN 4 THEN 3 " + + "WHEN 5 THEN 4 " + + "WHEN 6 THEN 6 " + + "ELSE 8 END AS INT) AS join_key", + "CAST(id AS INT) AS left_value") + + private def distinctNullableRightDf(spark: SparkSession): DataFrame = + spark.range(0, 8).selectExpr( + "CAST(CASE CAST(id AS INT) " + + "WHEN 0 THEN NULL " + + "WHEN 1 THEN 0 " + + "WHEN 2 THEN 1 " + + "WHEN 3 THEN 2 " + + "WHEN 4 THEN 4 " + + "WHEN 5 THEN 5 " + + "WHEN 6 THEN 7 " + + "ELSE 9 END AS INT) AS join_key", + "CAST(id * 10 AS INT) AS right_value") + IGNORE_ORDER_testSparkResultsAreEqual2("Test 
hash join", longsDf, biggerLongsDf, conf = shuffledJoinConf) { (A, B) => A.join(B, A("longs") === B("longs")) @@ -70,6 +104,25 @@ class JoinsSuite extends SparkQueryCompareTestSuite { (A, B) => A.join(B, A("longs") === B("longs"), "FullOuter") } + IGNORE_ORDER_testSparkResultsAreEqual2( + "Test shuffled distinct full join with nullable keys", + distinctNullableLeftDf, + distinctNullableRightDf, + conf = shuffledDistinctOuterConf) { + (left, right) => left.repartition(2).join(right.repartition(2), Seq("join_key"), "FullOuter") + } + + test("Test shuffled distinct full join with nullable keys uses sized hash join") { + withGpuSparkSession(spark => { + val left = distinctNullableLeftDf(spark).repartition(2) + val right = distinctNullableRightDf(spark).repartition(2) + val joined = left.join(right, Seq("join_key"), "FullOuter") + joined.collect() + assert(TestUtils.findOperator(joined.queryExecution.executedPlan, + _.isInstanceOf[GpuShuffledSizedHashJoinExec[_]]).isDefined) + }, shuffledDistinctOuterConf) + } + IGNORE_ORDER_testSparkResultsAreEqual2("Test cross join", longsDf, biggerLongsDf, conf = shuffledJoinConf) { (A, B) => A.join(B.hint("broadcast"), A("longs") < B("longs"), "Cross") From 1e2304ac134555a65c3dbdf78e86ab54d97e9eb2 Mon Sep 17 00:00:00 2001 From: Rishi Chandra Date: Wed, 22 Apr 2026 15:16:57 -0700 Subject: [PATCH 04/12] We don't need to switch join sides when reuse is enabled --- .../sql/rapids/execution/GpuHashJoin.scala | 66 ++++++------------- .../spark/rapids/BroadcastHashJoinSuite.scala | 13 ++++ 2 files changed, 34 insertions(+), 45 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala index 8d067ae1aff..ed7d01ebda6 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala @@ -1198,17 +1198,6 @@ abstract class BaseHashJoinIterator( } } - protected def canReusePhysicalBuildSide( - expectedBuildSide: GpuBuildSide, - leftRowCount: Long, - rightRowCount: Long): Boolean = { - !buildSideReuseDisabled && - expectedBuildSide == buildSide && - JoinBuildSideSelection.selectPhysicalBuildSide( - joinOptions.buildSideSelection, expectedBuildSide, leftRowCount, rightRowCount) == - expectedBuildSide - } - protected def disableBuildSideReuse(): Unit = { buildSideReuseDisabled = true val hashJoinToClose = cachedHashJoin @@ -1617,13 +1606,8 @@ class HashJoinIterator( } private def exactNumJoinRows(streamKeys: Table): Option[Long] = { - val (leftRowCount, rightRowCount) = buildSide match { - case GpuBuildRight => (streamKeys.getRowCount, built.numRows.toLong) - case GpuBuildLeft => (built.numRows.toLong, streamKeys.getRowCount) - } - joinType match { - case _: InnerLike if canReusePhysicalBuildSide(buildSide, leftRowCount, rightRowCount) => + case _: InnerLike => buildSide match { case GpuBuildRight => cachedHashJoinFor(GpuBuildRight) @@ -1643,21 +1627,17 @@ class HashJoinIterator( } } - private def reusedGenericInnerJoin( + private def cachedGenericInnerJoin( leftKeys: Table, rightKeys: Table, outputRowCount: Option[Long]): Option[GatherMapsResult] = { - if (!canReusePhysicalBuildSide(buildSide, leftKeys.getRowCount, rightKeys.getRowCount)) { - None - } else { - buildSide match { - case GpuBuildRight => - cachedHashJoinFor(GpuBuildRight) - .map(hashJoin => JoinImpl.innerHashJoinBuildRight(leftKeys, hashJoin, outputRowCount)) - 
case GpuBuildLeft => - cachedHashJoinFor(GpuBuildLeft) - .map(hashJoin => JoinImpl.innerHashJoinBuildLeft(rightKeys, hashJoin, outputRowCount)) - } + buildSide match { + case GpuBuildRight => + cachedHashJoinFor(GpuBuildRight) + .map(hashJoin => JoinImpl.innerHashJoinBuildRight(leftKeys, hashJoin, outputRowCount)) + case GpuBuildLeft => + cachedHashJoinFor(GpuBuildLeft) + .map(hashJoin => JoinImpl.innerHashJoinBuildLeft(rightKeys, hashJoin, outputRowCount)) } } @@ -1783,7 +1763,7 @@ class HashJoinIterator( } logJoinCardinality(leftKeys, rightKeys, implName) - val innerMaps = reusedGenericInnerJoin(leftKeys, rightKeys, numJoinRows).getOrElse { + val innerMaps = cachedGenericInnerJoin(leftKeys, rightKeys, numJoinRows).getOrElse { JoinImpl.innerHashJoin(leftKeys, rightKeys, compareNullsEqual, joinOptions.buildSideSelection, buildSide) } @@ -1828,7 +1808,7 @@ class HashJoinIterator( .map(hashJoin => JoinImpl.rightOuterHashJoinBuildLeft(rightKeys, hashJoin, numJoinRows)) .getOrElse(JoinImpl.rightOuterHashJoinBuildLeft(leftKeys, rightKeys, compareNullsEqual)) case _: InnerLike => - reusedGenericInnerJoin(leftKeys, rightKeys, numJoinRows).getOrElse { + cachedGenericInnerJoin(leftKeys, rightKeys, numJoinRows).getOrElse { JoinImpl.innerHashJoin(leftKeys, rightKeys, compareNullsEqual, joinOptions.buildSideSelection, buildSide) } @@ -2133,20 +2113,16 @@ class HashJoinStreamSideIterator( private[this] var builtSideTracker: Option[SpillableColumnarBatch] = buildSideTrackerInit - private def reusedUnconditionalInnerHashJoin( + private def cachedUnconditionalInnerHashJoin( leftKeys: Table, rightKeys: Table): Option[GatherMapsResult] = { - if (!canReusePhysicalBuildSide(cudfBuildSide, leftKeys.getRowCount, rightKeys.getRowCount)) { - None - } else { - cudfBuildSide match { - case GpuBuildRight => - cachedHashJoinFor(GpuBuildRight) - .map(hashJoin => JoinImpl.innerHashJoinBuildRight(leftKeys, hashJoin)) - case GpuBuildLeft => - cachedHashJoinFor(GpuBuildLeft) - .map(hashJoin => JoinImpl.innerHashJoinBuildLeft(rightKeys, hashJoin)) - } + cudfBuildSide match { + case GpuBuildRight => + cachedHashJoinFor(GpuBuildRight) + .map(hashJoin => JoinImpl.innerHashJoinBuildRight(leftKeys, hashJoin)) + case GpuBuildLeft => + cachedHashJoinFor(GpuBuildLeft) + .map(hashJoin => JoinImpl.innerHashJoinBuildLeft(rightKeys, hashJoin)) } } @@ -2254,7 +2230,7 @@ class HashJoinStreamSideIterator( } logJoinCardinality(leftKeys, rightKeys, implName, originalJoinType) - val innerMaps = reusedUnconditionalInnerHashJoin(leftKeys, rightKeys).getOrElse { + val innerMaps = cachedUnconditionalInnerHashJoin(leftKeys, rightKeys).getOrElse { JoinImpl.innerHashJoin(leftKeys, rightKeys, compareNullsEqual, joinOptions.buildSideSelection, cudfBuildSide) } @@ -2303,7 +2279,7 @@ class HashJoinStreamSideIterator( .map(hashJoin => JoinImpl.rightOuterHashJoinBuildLeft(rightKeys, hashJoin)) .getOrElse(JoinImpl.rightOuterHashJoinBuildLeft(leftKeys, rightKeys, compareNullsEqual)) case Inner => - reusedUnconditionalInnerHashJoin(leftKeys, rightKeys).getOrElse { + cachedUnconditionalInnerHashJoin(leftKeys, rightKeys).getOrElse { JoinImpl.innerHashJoin(leftKeys, rightKeys, compareNullsEqual, joinOptions.buildSideSelection, cudfBuildSide) } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala index abdce14eaf8..93b23df34b5 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala +++ 
b/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala @@ -45,6 +45,11 @@ class BroadcastHashJoinSuite extends SparkQueryCompareTestSuite { "CAST(id % 4 AS INT) AS join_key", "CAST(id AS INT) AS build_value") + private def largerNonDistinctBuildDf(spark: SparkSession): DataFrame = + spark.range(0, 1024).selectExpr( + "CAST(id % 4 AS INT) AS join_key", + "CAST(id AS INT) AS build_value") + private def nullableProbeDf(spark: SparkSession): DataFrame = spark.range(0, 8).selectExpr( "CAST(CASE CAST(id AS INT) " + @@ -151,6 +156,14 @@ class BroadcastHashJoinSuite extends SparkQueryCompareTestSuite { (probe, build) => probe.join(broadcast(build), Seq("join_key"), "inner") } + IGNORE_ORDER_testSparkResultsAreEqual2( + "broadcast hash join reuse non-distinct inner build right with auto build-side selection", + streamedProbeDf, + largerNonDistinctBuildDf, + conf = broadcastReuseConf.clone().set("spark.rapids.sql.join.buildSide", "AUTO")) { + (probe, build) => probe.join(broadcast(build), Seq("join_key"), "inner") + } + IGNORE_ORDER_testSparkResultsAreEqual2( "broadcast hash join reuse non-distinct inner build left", nonDistinctBuildDf, From bccf35c7b71a78c75bf3dda466b5b4c06539cadd Mon Sep 17 00:00:00 2001 From: Rishi Chandra Date: Thu, 23 Apr 2026 13:55:15 -0700 Subject: [PATCH 05/12] Recomputable spill handle, tests --- integration_tests/src/main/python/asserts.py | 1 + .../src/main/python/join_test.py | 84 ++++++- .../com/nvidia/spark/rapids/GpuMetrics.scala | 4 + .../spark/rapids/spill/SpillFramework.scala | 184 ++++++++++++++- .../rapids/ExecutionPlanCaptureCallback.scala | 10 +- ...mmedExecutionPlanCaptureCallbackImpl.scala | 35 ++- .../execution/BroadcastCachedBuildSide.scala | 128 ++++++++++ .../execution/GpuBroadcastExchangeExec.scala | 35 +++ .../GpuBroadcastHashJoinExecBase.scala | 22 +- .../sql/rapids/execution/GpuHashJoin.scala | 222 +++++++++++------- .../execution/GpuBroadcastHashJoinExec.scala | 2 + .../spark/rapids/BroadcastHashJoinSuite.scala | 29 +++ .../SharedRecomputableDeviceHandleSuite.scala | 161 +++++++++++++ 13 files changed, 818 insertions(+), 99 deletions(-) create mode 100644 sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/BroadcastCachedBuildSide.scala create mode 100644 tests/src/test/scala/com/nvidia/spark/rapids/spill/SharedRecomputableDeviceHandleSuite.scala diff --git a/integration_tests/src/main/python/asserts.py b/integration_tests/src/main/python/asserts.py index cc9013cd845..5d3aa3f1fba 100644 --- a/integration_tests/src/main/python/asserts.py +++ b/integration_tests/src/main/python/asserts.py @@ -528,6 +528,7 @@ def assert_cpu_and_gpu_are_equal_collect_with_capture(func, _sort_locally(from_cpu, from_gpu) assert_equal(from_cpu, from_gpu) + return gpu_df def assert_cpu_and_gpu_are_equal_sql_with_capture(df_fun, sql, diff --git a/integration_tests/src/main/python/join_test.py b/integration_tests/src/main/python/join_test.py index 500d49f9b81..a1aa7f93236 100644 --- a/integration_tests/src/main/python/join_test.py +++ b/integration_tests/src/main/python/join_test.py @@ -19,7 +19,7 @@ from asserts import (assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_row_counts_equal, assert_gpu_fallback_collect, assert_cpu_and_gpu_are_equal_collect_with_capture, assert_cpu_and_gpu_are_equal_sql_with_capture, assert_gpu_and_cpu_are_equal_sql) -from conftest import is_emr_runtime +from conftest import is_emr_runtime, spark_jvm from data_gen import * from marks import ignore_order, allow_non_gpu, incompat, validate_execs_in_gpu_plan, 
disable_ansi_mode from spark_session import with_cpu_session, is_before_spark_330, is_databricks_runtime, is_spark_400_or_later, is_spark_411_or_later @@ -217,6 +217,88 @@ def do_join(spark): }) +@ignore_order(local=True) +@pytest.mark.parametrize("kudo_enabled", ["true", "false"], ids=idfn) +def test_broadcast_hash_join_reuse_large_build_aqe(kudo_enabled): + def do_join(spark): + left = spark.range(4096, numPartitions=4).selectExpr( + "CAST(id % 64 AS INT) AS join_key", + "CAST(id AS INT) AS probe_value") + right = spark.range(8192).selectExpr( + "CAST(id % 64 AS INT) AS join_key", + "CAST(id * 3 AS INT) AS build_value") + return left.join(broadcast(right), "join_key", "inner").groupBy("join_key").count() + + conf = { + 'spark.sql.adaptive.enabled': 'true', + 'spark.sql.autoBroadcastJoinThreshold': '-1', + 'spark.sql.shuffle.partitions': '4', + 'spark.rapids.sql.batchSizeBytes': '1024', + 'spark.rapids.sql.join.broadcastHashTable.reuse': 'true', + kudo_enabled_conf_key: kudo_enabled + } + gpu_df = assert_cpu_and_gpu_are_equal_collect_with_capture( + do_join, + exist_classes='GpuBroadcastHashJoinExec', + conf=conf) + jvm = spark_jvm() + cache_builds = jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.sumMetric( + gpu_df._jdf, + 'GpuBroadcastHashJoinExec', + 'buildSideCacheBuilds') + cache_hits = jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.sumMetric( + gpu_df._jdf, + 'GpuBroadcastHashJoinExec', + 'buildSideCacheHits') + assert cache_builds == 1 + if gpu_df.sparkSession.sparkContext.master.startswith("local"): + assert cache_hits >= 3 + else: + assert cache_hits > 0 + + +@ignore_order(local=True) +@pytest.mark.parametrize("kudo_enabled", ["true", "false"], ids=idfn) +def test_broadcast_hash_join_reuse_same_broadcast_multi_join_aqe(kudo_enabled): + def do_join(spark): + fact = spark.range(512, numPartitions=4).selectExpr( + "CAST(id % 8 AS INT) AS join_key", + "CAST(id AS INT) AS probe_value") + dim = broadcast(spark.range(32).selectExpr( + "CAST(id % 8 AS INT) AS join_key", + "CAST(id AS INT) AS build_value")) + return fact.join(dim, "join_key", "inner").select("join_key", "probe_value") \ + .join(dim, "join_key", "inner").groupBy("join_key").count() + + conf = { + 'spark.sql.adaptive.enabled': 'true', + 'spark.sql.autoBroadcastJoinThreshold': '-1', + 'spark.sql.exchange.reuse': 'true', + 'spark.sql.shuffle.partitions': '4', + 'spark.rapids.sql.batchSizeBytes': '1024', + 'spark.rapids.sql.join.broadcastHashTable.reuse': 'true', + kudo_enabled_conf_key: kudo_enabled + } + gpu_df = assert_cpu_and_gpu_are_equal_collect_with_capture( + do_join, + exist_classes='GpuBroadcastHashJoinExec,ReusedExchangeExec', + conf=conf) + jvm = spark_jvm() + cache_builds = jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.sumMetric( + gpu_df._jdf, + 'GpuBroadcastHashJoinExec', + 'buildSideCacheBuilds') + cache_hits = jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.sumMetric( + gpu_df._jdf, + 'GpuBroadcastHashJoinExec', + 'buildSideCacheHits') + assert cache_builds == 1 + if gpu_df.sparkSession.sparkContext.master.startswith("local"): + assert cache_hits >= 3 + else: + assert cache_hits > 0 + + # local sort because of https://github.com/NVIDIA/spark-rapids/issues/84 # After 3.1.0 is the min spark version we can drop this @ignore_order(local=True) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuMetrics.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuMetrics.scala index b32184d08ce..59d45f89dec 100644 --- 
a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuMetrics.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuMetrics.scala @@ -116,6 +116,8 @@ object GpuMetric extends Logging { val BUILD_DATA_SIZE = "buildDataSize" val BUILD_TIME = "buildTime" val STREAM_TIME = "streamTime" + val BUILD_SIDE_CACHE_BUILDS = "buildSideCacheBuilds" + val BUILD_SIDE_CACHE_HITS = "buildSideCacheHits" val NUM_TASKS_FALL_BACKED = "numTasksFallBacked" val NUM_TASKS_REPARTITIONED = "numTasksRepartitioned" val NUM_TASKS_SKIPPED_AGG = "numTasksSkippedAgg" @@ -173,6 +175,8 @@ object GpuMetric extends Logging { val DESCRIPTION_BUILD_DATA_SIZE = "build side size" val DESCRIPTION_BUILD_TIME = "build time" val DESCRIPTION_STREAM_TIME = "stream time" + val DESCRIPTION_BUILD_SIDE_CACHE_BUILDS = "cached build side builds" + val DESCRIPTION_BUILD_SIDE_CACHE_HITS = "cached build side hits" val DESCRIPTION_NUM_TASKS_FALL_BACKED = "number of sort fallback tasks" val DESCRIPTION_NUM_TASKS_REPARTITIONED = "number of tasks repartitioned for agg" val DESCRIPTION_NUM_TASKS_SKIPPED_AGG = "number of tasks skipped aggregation" diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala index 2a05e486e9a..79ae4bad0d9 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala @@ -278,11 +278,18 @@ trait SpillableHandle extends StoreHandle with Logging { } /** - * Spillable handles that can be materialized on the device. + * Contract for handles tracked by the device spill store (see [[SpillableHandle]]). + */ +trait DeviceStoreHandle extends SpillableHandle { + def releaseSpilled(): Unit +} + +/** + * Spillable handles that can be materialized on the device and spilled to host. * @tparam T an auto closeable subclass. `dev` tracks an instance of this object, * on the device. */ -trait DeviceSpillableHandle[T <: AutoCloseable] extends SpillableHandle { +trait DeviceSpillableHandle[T <: AutoCloseable] extends DeviceStoreHandle { private[spill] var dev: Option[T] private[spill] override def spillable: Boolean = synchronized { @@ -305,7 +312,7 @@ trait DeviceSpillableHandle[T <: AutoCloseable] extends SpillableHandle { * free a device buffer that the worker thread isn't done with). * See https://github.com/NVIDIA/spark-rapids/issues/8610 for more info. */ - def releaseSpilled(): Unit = { + override def releaseSpilled(): Unit = { releaseDeviceResource() } @@ -319,6 +326,173 @@ trait DeviceSpillableHandle[T <: AutoCloseable] extends SpillableHandle { } } +object SharedRecomputableDeviceHandle { + final class Lease[T <: AutoCloseable] private[spill] ( + handle: SharedRecomputableDeviceHandle[T], + val resource: T) extends AutoCloseable { + private[this] var closed = false + + override def close(): Unit = synchronized { + if (closed) { + throw new IllegalStateException("Close called too many times on recomputable handle lease") + } + closed = true + handle.releasePin() + } + } + + def apply[T <: AutoCloseable]( + approxSizeInBytes: Long, + initialValue: T)( + rebuild: => T): SharedRecomputableDeviceHandle[T] = { + val handle = new SharedRecomputableDeviceHandle(approxSizeInBytes, initialValue, () => rebuild) + SpillFramework.stores.deviceStore.track(handle) + handle + } +} + +/** + * Spill-framework handle for device-only state that is cheaper to recompute than to spill. 
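+ *
+ * A construction sketch through the companion factory (`buildState` and
+ * `approxBytes` are hypothetical names for illustration, not part of this patch;
+ * any device-side build step works):
+ * {{{
+ *   val handle = SharedRecomputableDeviceHandle(approxBytes, buildState()) {
+ *     buildState() // invoked again to rebuild the state after an eviction
+ *   }
+ * }}}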
+ * + * When this handle is selected for spilling, it does not copy anything to host or disk. Instead + * it marks the current device state as evicted and returns `approxSizeInBytes` so the spill + * framework accounts for the freed device memory. The actual close of the evicted state is + * deferred to `releaseSpilled`, after device synchronization has completed. + * + * The protected device state does not expose reference counts, so spillability cannot follow the + * usual `getRefCount == 1` pattern used by cuDF tables and buffers. Instead, callers must pin the + * state through `acquire`, and spillability is derived from an application-level pin count. + */ +class SharedRecomputableDeviceHandle[T <: AutoCloseable] private[spill] ( + override val approxSizeInBytes: Long, + initialValue: T, + rebuild: () => T) extends DeviceStoreHandle with Logging { + import SharedRecomputableDeviceHandle.Lease + + private[spill] var dev: Option[T] = Some(initialValue) + private[this] var pendingRelease: Seq[T] = Seq.empty + private[this] var pinCount: Int = 0 + private[this] var rebuilding: Boolean = false + + private[spill] override def spillable: Boolean = synchronized { + super.spillable && dev.isDefined && pinCount == 0 + } + + def acquire(): Lease[T] = { + var materialized: Option[T] = None + var shouldBuild = false + while (materialized.isEmpty) { + shouldBuild = synchronized { + if (closed) { + throw new IllegalStateException("attempting to materialize a closed handle") + } else if (dev.isDefined) { + pinCount += 1 + materialized = dev + false + } else if (rebuilding) { + wait() + false + } else { + rebuilding = true + true + } + } + + if (shouldBuild) { + var rebuilt: Option[T] = None + try { + rebuilt = Some(rebuild()) + var shouldTrack = false + synchronized { + rebuilding = false + if (closed) { + notifyAll() + throw new IllegalStateException("attempting to materialize a closed handle") + } + dev = rebuilt + pinCount += 1 + materialized = rebuilt + shouldTrack = true + notifyAll() + } + if (shouldTrack) { + SpillFramework.stores.deviceStore.track(this) + } + } catch { + case t: Throwable => + rebuilt.foreach(_.close()) + synchronized { + rebuilding = false + notifyAll() + } + throw t + } + } + } + new Lease(this, materialized.get) + } + + private[spill] def releasePin(): Unit = synchronized { + if (pinCount <= 0) { + throw new IllegalStateException("releasePin called without a matching acquire") + } + pinCount -= 1 + } + + override def spill(): Long = { + var evicted: Option[T] = None + val thisThreadSpills = synchronized { + if (!closed && dev.isDefined && pinCount == 0 && !spilling) { + spilling = true + evicted = dev + dev = None + true + } else { + false + } + } + if (thisThreadSpills) { + SpillFramework.removeFromDeviceStore(this) + var shouldClose = false + executeSpill { + synchronized { + pendingRelease = pendingRelease ++ evicted.toSeq + spilling = false + shouldClose = closed + } + 0L + } + if (shouldClose) { + doClose() + } + approxSizeInBytes + } else { + 0L + } + } + + override def releaseSpilled(): Unit = { + val toClose = synchronized { + val release = pendingRelease + pendingRelease = Seq.empty + release + } + toClose.safeClose() + } + + override def doClose(): Unit = { + SpillFramework.removeFromDeviceStore(this) + val toClose = synchronized { + val current = dev + val release = pendingRelease + dev = None + pendingRelease = Seq.empty + current.toSeq ++ release.toSeq + } + toClose.safeClose() + } +} + /** * Spillable handles that can be materialized on the host. 
* @tparam T an auto closeable subclass. `host` tracks an instance of this object, @@ -1739,7 +1913,7 @@ class SpillableHostStore(val maxSize: Option[Long] = None) override protected def spillNvtxRange: NvtxId = NvtxRegistry.DISK_SPILL } -class SpillableDeviceStore extends SpillableStore[DeviceSpillableHandle[_]] { +class SpillableDeviceStore extends SpillableStore[DeviceStoreHandle] { override protected def spillNvtxRange: NvtxId = NvtxRegistry.DEVICE_SPILL override def postSpill(plan: SpillPlan): Unit = { @@ -2152,7 +2326,7 @@ object SpillFramework extends Logging { // if the stores have already shut down, we don't want to create them here // so we use `storesInternal` directly in these remove functions. - private[spill] def removeFromDeviceStore(handle: DeviceSpillableHandle[_]): Unit = { + private[spill] def removeFromDeviceStore(handle: DeviceStoreHandle): Unit = { synchronized { Option(storesInternal).map(_.deviceStore) }.foreach(_.remove(handle)) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExecutionPlanCaptureCallback.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExecutionPlanCaptureCallback.scala index 2ca8ac0bf9e..142cf9bca2e 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExecutionPlanCaptureCallback.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExecutionPlanCaptureCallback.scala @@ -35,6 +35,8 @@ trait ExecutionPlanCaptureCallbackBase { def assertContainsAnsiCast(df: DataFrame): Unit def assertNotContain(gpuPlan: SparkPlan, className: String): Unit def assertNotContain(df: DataFrame, gpuClass: String): Unit + def sumMetric(gpuPlan: SparkPlan, className: String, metricName: String): Long + def sumMetric(df: DataFrame, className: String, metricName: String): Long def assertDidFallBack(gpuPlan: SparkPlan, fallbackCpuClass: String): Unit def assertDidFallBack(df: DataFrame, fallbackCpuClass: String): Unit def assertDidFallBack(gpuPlans: Array[SparkPlan], fallbackCpuClass: String): Unit @@ -85,6 +87,12 @@ object ExecutionPlanCaptureCallback extends ExecutionPlanCaptureCallbackBase { override def assertNotContain(df: DataFrame, gpuClass: String): Unit = impl.assertNotContain(df, gpuClass) + override def sumMetric(gpuPlan: SparkPlan, className: String, metricName: String): Long = + impl.sumMetric(gpuPlan, className, metricName) + + override def sumMetric(df: DataFrame, className: String, metricName: String): Long = + impl.sumMetric(df, className, metricName) + override def assertDidFallBack(gpuPlan: SparkPlan, fallbackCpuClass: String): Unit = impl.assertDidFallBack(gpuPlan, fallbackCpuClass) @@ -128,4 +136,4 @@ class ExecutionPlanCaptureCallback extends QueryExecutionListener { trait AdaptiveSparkPlanHelperShim { def collect[B](p: SparkPlan)(pf: PartialFunction[SparkPlan, B]): Seq[B] -} \ No newline at end of file +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ShimmedExecutionPlanCaptureCallbackImpl.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ShimmedExecutionPlanCaptureCallbackImpl.scala index ccc430c65b9..98840b5272c 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ShimmedExecutionPlanCaptureCallbackImpl.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ShimmedExecutionPlanCaptureCallbackImpl.scala @@ -185,6 +185,21 @@ class ShimmedExecutionPlanCaptureCallbackImpl extends ExecutionPlanCaptureCallba assertNotContain(executedPlan, gpuClass) } + override def sumMetric(gpuPlan: SparkPlan, className: String, metricName: String): Long = { + 
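+    // Sum the named metric over every plan node of the requested class; the
+    // search descends through AQE, query-stage, and reused exchange/subquery
+    // wrapper nodes so metrics inside reused stages are counted as well.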
val executedPlan = extractExecutedPlan(gpuPlan) + val matchingPlans = collectPlansMatching(executedPlan, p => PlanUtils.sameClass(p, className)) + assert(matchingPlans.nonEmpty, s"Could not find $className in the Spark plan\n$executedPlan") + matchingPlans.map { plan => + plan.metrics.getOrElse(metricName, { + throw new AssertionError(s"Could not find metric $metricName on plan node\n$plan") + }).value + }.sum + } + + override def sumMetric(df: DataFrame, className: String, metricName: String): Long = { + sumMetric(df.queryExecution.executedPlan, className, metricName) + } + override def assertContainsAnsiCast(df: DataFrame): Unit = { val executedPlan = extractExecutedPlan(df.queryExecution.executedPlan) assert(containsPlanMatching(executedPlan, @@ -253,5 +268,23 @@ class ShimmedExecutionPlanCaptureCallbackImpl extends ExecutionPlanCaptureCallba case p => p.children.exists(plan => containsPlanMatching(plan, f)) }.nonEmpty -} + private def collectPlansMatching(plan: SparkPlan, f: SparkPlan => Boolean): Seq[SparkPlan] = { + val matching = ArrayBuffer[SparkPlan]() + + def recurse(currentPlan: SparkPlan): Unit = currentPlan match { + case p: AdaptiveSparkPlanExec => recurse(p.executedPlan) + case p: QueryStageExec => recurse(p.plan) + case p: ReusedSubqueryExec => recurse(p.child) + case p: ReusedExchangeExec => recurse(p.child) + case p => + if (f(p)) { + matching += p + } + p.children.foreach(recurse) + } + + recurse(plan) + matching.toSeq + } +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/BroadcastCachedBuildSide.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/BroadcastCachedBuildSide.scala new file mode 100644 index 00000000000..cfcaadc4255 --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/BroadcastCachedBuildSide.scala @@ -0,0 +1,128 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.rapids.execution + +import ai.rapids.cudf.{HashJoin => CudfHashJoin, Table} +import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuProjectExec, SpillableColumnarBatch} +import com.nvidia.spark.rapids.Arm.withResource +import com.nvidia.spark.rapids.RmmRapidsRetryIterator.withRetryNoSplit +import com.nvidia.spark.rapids.jni.DistinctHashJoin +import com.nvidia.spark.rapids.spill.SharedRecomputableDeviceHandle + +import org.apache.spark.sql.vectorized.ColumnarBatch + +sealed trait CachedBuildSide extends AutoCloseable { + def buildStats: JoinBuildSideStats +} + +final class CachedHashJoin( + override val buildStats: JoinBuildSideStats, + val handle: SharedRecomputableDeviceHandle[CudfHashJoin]) extends CachedBuildSide { + override def close(): Unit = handle.close() +} + +final class CachedDistinctHashJoin( + override val buildStats: JoinBuildSideStats, + val handle: SharedRecomputableDeviceHandle[DistinctHashJoin]) extends CachedBuildSide { + override def close(): Unit = handle.close() +} + +case class BroadcastCachedBuildSideKey( + projectedBuildKeys: Seq[String], + compareNullsEqual: Boolean, + filterOutNulls: Boolean) + +object BroadcastCachedBuildSide { + def key( + boundBuiltKeys: Seq[GpuExpression], + compareNullsEqual: Boolean, + filterOutNulls: Boolean): BroadcastCachedBuildSideKey = { + BroadcastCachedBuildSideKey( + boundBuiltKeys.map(_.canonicalized.sql), + compareNullsEqual, + filterOutNulls) + } + + private def withBuildKeys[T]( + broadcastBatch: SpillableColumnarBatch, + boundBuiltKeys: Seq[GpuExpression], + filterOutNulls: Boolean)(f: Table => T): T = { + def projectAndApply(cb: ColumnarBatch): T = { + withResource(GpuProjectExec.project(cb, boundBuiltKeys)) { buildKeys => + withResource(GpuColumnVector.from(buildKeys)) { buildKeysTable => + f(buildKeysTable) + } + } + } + if (filterOutNulls) { + val retainedBatch = broadcastBatch.incRefCount() + withResource(GpuHashJoin.filterNullsWithRetryAndClose(retainedBatch, boundBuiltKeys)) { + projectAndApply + } + } else { + val retainedBatch = broadcastBatch.incRefCount() + withRetryNoSplit(retainedBatch) { _ => + withResource(retainedBatch.getColumnarBatch()) { projectAndApply } + } + } + } + + /** + * cuDF's reusable hash join handles are safe for concurrent probes. The executor-wide cache + * therefore pins the live handle while a task is probing it and relies on + * `SharedRecomputableDeviceHandle` to track spillability with an application-level pin count. 
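+ *
+ * A minimal probe sketch (hypothetical caller; `probe` is illustrative and not
+ * part of this patch):
+ * {{{
+ *   create(batch, keys, compareNullsEqual = true, filterOutNulls = false) match {
+ *     case hj: CachedHashJoin =>
+ *       withResource(hj.handle.acquire()) { lease => probe(lease.resource) }
+ *     case dhj: CachedDistinctHashJoin =>
+ *       withResource(dhj.handle.acquire()) { lease => probe(lease.resource) }
+ *   }
+ * }}}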
+ */ + def create( + broadcastBatch: SpillableColumnarBatch, + boundBuiltKeys: Seq[GpuExpression], + compareNullsEqual: Boolean, + filterOutNulls: Boolean): CachedBuildSide = { + def buildHashJoin(): CudfHashJoin = { + withBuildKeys(broadcastBatch, boundBuiltKeys, filterOutNulls) { buildKeys => + new CudfHashJoin(buildKeys, compareNullsEqual) + } + } + + def buildDistinctHashJoin(): DistinctHashJoin = { + withBuildKeys(broadcastBatch, boundBuiltKeys, filterOutNulls) { buildKeys => + new DistinctHashJoin(buildKeys, compareNullsEqual) + } + } + + withBuildKeys(broadcastBatch, boundBuiltKeys, filterOutNulls) { buildKeys => + val stats = JoinBuildSideStats.fromTable(buildKeys) + val approxSizeInBytes = buildKeys.getDeviceMemorySize + if (stats.isDistinct) { + new CachedDistinctHashJoin( + stats, + SharedRecomputableDeviceHandle( + approxSizeInBytes, + new DistinctHashJoin(buildKeys, compareNullsEqual)) { + buildDistinctHashJoin() + }) + } else { + new CachedHashJoin( + stats, + SharedRecomputableDeviceHandle( + approxSizeInBytes, + new CudfHashJoin(buildKeys, compareNullsEqual)) { + buildHashJoin() + }) + } + } + } +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala index 8420a482cdf..22320d2bc2e 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala @@ -80,9 +80,17 @@ class SerializeConcatHostBuffersDeserializeBatch( // used for memoization of deserialization to GPU on Executor @transient private var batchInternal: SpillableColumnarBatch = null + @transient private var cachedBuildSideCache: mutable.HashMap[BroadcastCachedBuildSideKey, CachedBuildSide] = null private def maybeGpuBatch: Option[SpillableColumnarBatch] = Option(batchInternal) + private def cachedBuildSides: mutable.HashMap[BroadcastCachedBuildSideKey, CachedBuildSide] = { + if (cachedBuildSideCache == null) { + cachedBuildSideCache = mutable.HashMap.empty + } + cachedBuildSideCache + } + def batch: SpillableColumnarBatch = this.synchronized { maybeGpuBatch.getOrElse { NvtxRegistry.BROADCAST_MANIFEST_BATCH { @@ -148,6 +156,31 @@ class SerializeConcatHostBuffersDeserializeBatch( } } + def getCachedBuildSide( + boundBuiltKeys: Seq[GpuExpression], + compareNullsEqual: Boolean, + filterOutNulls: Boolean, + cacheBuilds: GpuMetric = NoopMetric, + cacheHits: GpuMetric = NoopMetric): CachedBuildSide = this.synchronized { + val cacheKey = BroadcastCachedBuildSide.key( + boundBuiltKeys, + compareNullsEqual, + filterOutNulls) + cachedBuildSides.get(cacheKey).map { cached => + cacheHits += 1 + cached + }.getOrElse { + cacheBuilds += 1 + val cached = BroadcastCachedBuildSide.create( + batch, + boundBuiltKeys, + compareNullsEqual, + filterOutNulls) + cachedBuildSides.put(cacheKey, cached) + cached + } + } + private def writeObject(out: ObjectOutputStream): Unit = { doWriteObject(out) } @@ -246,8 +279,10 @@ class SerializeConcatHostBuffersDeserializeBatch( */ def closeInternal(): Unit = this.synchronized { Seq(data, batchInternal).safeClose() + Option(cachedBuildSideCache).foreach(cache => cache.values.toSeq.safeClose()) data = null batchInternal = null + cachedBuildSideCache = null } @scala.annotation.nowarn("msg=method finalize in class Object is deprecated") diff --git 
a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExecBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExecBase.scala index 52aa61154d5..4b113fe7168 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExecBase.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExecBase.scala @@ -119,7 +119,11 @@ abstract class GpuBroadcastHashJoinExecBase( override lazy val additionalMetrics: Map[String, GpuMetric] = Map( OP_TIME_LEGACY -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_OP_TIME_LEGACY), STREAM_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_STREAM_TIME), - JOIN_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_JOIN_TIME)) + JOIN_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_JOIN_TIME), + BUILD_SIDE_CACHE_BUILDS -> + createMetric(MODERATE_LEVEL, DESCRIPTION_BUILD_SIDE_CACHE_BUILDS), + BUILD_SIDE_CACHE_HITS -> + createMetric(MODERATE_LEVEL, DESCRIPTION_BUILD_SIDE_CACHE_HITS)) override def requiredChildDistribution: Seq[Distribution] = { val mode = HashedRelationBroadcastMode(buildKeys) @@ -152,6 +156,8 @@ abstract class GpuBroadcastHashJoinExecBase( val opTime = gpuLongMetric(OP_TIME_LEGACY) val streamTime = gpuLongMetric(STREAM_TIME) val joinTime = gpuLongMetric(JOIN_TIME) + val buildSideCacheBuilds = gpuLongMetric(BUILD_SIDE_CACHE_BUILDS) + val buildSideCacheHits = gpuLongMetric(BUILD_SIDE_CACHE_HITS) val targetSize = RapidsConf.GPU_BATCH_SIZE_BYTES.get(conf) val joinOptions = RapidsConf.getJoinOptions(conf, targetSize) @@ -168,6 +174,10 @@ abstract class GpuBroadcastHashJoinExecBase( broadcastRelation, buildSchema, new CollectTimeIterator(NvtxRegistry.BROADCAST_JOIN_STREAM, it, streamTime)) + val broadcastBatch = broadcastRelation.value match { + case batch: SerializeConcatHostBuffersDeserializeBatch => Some(batch) + case _ => None + } if (localIsNullAwareAntiJoin) { // This is to support the null-aware anti join for the LeftAnti join with // BuildRight. See the config "spark.sql.optimizeNullAwareAntiJoin". 
@@ -188,12 +198,18 @@ abstract class GpuBroadcastHashJoinExecBase( boundStreamKeys) } doJoin(builtBatch, nullFilteredStreamIter, joinOptions, numOutputRows, - numOutputBatches, opTime, joinTime, enableBuildSideReuse) + numOutputBatches, opTime, joinTime, enableBuildSideReuse, + broadcastBatch = broadcastBatch, + buildSideCacheBuilds = buildSideCacheBuilds, + buildSideCacheHits = buildSideCacheHits) } } else { // builtBatch will be closed in doJoin doJoin(builtBatch, streamIter, joinOptions, numOutputRows, numOutputBatches, opTime, - joinTime, enableBuildSideReuse) + joinTime, enableBuildSideReuse, + broadcastBatch = broadcastBatch, + buildSideCacheBuilds = buildSideCacheBuilds, + buildSideCacheHits = buildSideCacheHits) } } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala index ed7d01ebda6..8f5cb83c58f 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala @@ -1088,6 +1088,13 @@ case class JoinCardinalityStats( case class JoinBuildSideStats(streamMagnificationFactor: Double, isDistinct: Boolean) object JoinBuildSideStats { + def fromTable(buildKeys: Table): JoinBuildSideStats = { + val builtCount = buildKeys.distinctCount(NullEquality.EQUAL) + val isDistinct = builtCount == buildKeys.getRowCount + val magnificationFactor = buildKeys.getRowCount.toDouble / builtCount + JoinBuildSideStats(magnificationFactor, isDistinct) + } + def fromBatch(batch: ColumnarBatch, boundBuildKeys: Seq[GpuExpression]): JoinBuildSideStats = { // This is okay because the build keys must be deterministic @@ -1096,10 +1103,7 @@ object JoinBuildSideStats { // will be for each input row on the stream side. This does not take into account // the join type, data skew or even if the keys actually match. 
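// For example, 1024 build rows spanning 4 distinct keys yield a magnification
// factor of 1024 / 4 = 256, i.e. each matching stream row is expected to
// produce roughly 256 output rows from an inner join.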
withResource(GpuColumnVector.from(buildKeys)) { keysTable => - val builtCount = keysTable.distinctCount(NullEquality.EQUAL) - val isDistinct = builtCount == buildKeys.numRows() - val magnificationFactor = buildKeys.numRows().toDouble / builtCount - JoinBuildSideStats(magnificationFactor, isDistinct) + fromTable(keysTable) } } } @@ -1116,6 +1120,7 @@ abstract class BaseHashJoinIterator( built: LazySpillableColumnarBatch, boundBuiltKeys: Seq[GpuExpression], buildStatsOpt: Option[JoinBuildSideStats], + cachedBuildSide: Option[CachedBuildSide], stream: Iterator[LazySpillableColumnarBatch], boundStreamKeys: Seq[GpuExpression], streamAttributes: Seq[Attribute], @@ -1137,7 +1142,8 @@ abstract class BaseHashJoinIterator( opTime = opTime, joinTime = joinTime) { // We can cache this because the build side is not changing - protected lazy val buildStats: JoinBuildSideStats = buildStatsOpt.getOrElse { + protected lazy val buildStats: JoinBuildSideStats = buildStatsOpt + .orElse(cachedBuildSide.map(_.buildStats)).getOrElse { joinType match { case _: InnerLike | LeftOuter | RightOuter | FullOuter | LeftSemi | LeftAnti => built.checkpoint() @@ -1152,59 +1158,58 @@ abstract class BaseHashJoinIterator( } } - private[this] var buildSideReuseDisabled = !enableBuildSideReuse - private[this] var cachedHashJoin: Option[CudfHashJoin] = None - private[this] var cachedDistinctHashJoin: Option[DistinctHashJoin] = None + private[this] var cachedBuildSideDisabled = !enableBuildSideReuse || cachedBuildSide.isEmpty - private def createCachedBuildSideReuseHandle[T <: AutoCloseable](factory: Table => T): T = { - built.checkpoint() - withRetryNoSplit { - withRestoreOnRetry(built) { - withResource(GpuProjectExec.project(built.getBatch, boundBuiltKeys)) { builtKeys => - try { - withResource(GpuColumnVector.from(builtKeys)) { builtKeysTable => - factory(builtKeysTable) - } - } finally { - built.allowSpilling() - } - } - } - } + private def canUseCachedBuildSide(expectedBuildSide: GpuBuildSide): Boolean = { + !cachedBuildSideDisabled && + expectedBuildSide == buildSide && + cachedBuildSide.isDefined + } + + protected def canUseCachedHashJoin(expectedBuildSide: GpuBuildSide): Boolean = { + canUseCachedBuildSide(expectedBuildSide) && cachedBuildSide.exists(_.isInstanceOf[CachedHashJoin]) + } + + protected def canUseCachedDistinctHashJoin(expectedBuildSide: GpuBuildSide): Boolean = { + canUseCachedBuildSide(expectedBuildSide) && + cachedBuildSide.exists(_.isInstanceOf[CachedDistinctHashJoin]) } - protected def cachedHashJoinFor(expectedBuildSide: GpuBuildSide): Option[CudfHashJoin] = { - if (buildSideReuseDisabled || buildStats.isDistinct || expectedBuildSide != buildSide) { + protected def withCachedHashJoin[T](expectedBuildSide: GpuBuildSide)(f: CudfHashJoin => T): Option[T] = { + if (!canUseCachedHashJoin(expectedBuildSide)) { None } else { - if (cachedHashJoin.isEmpty) { - cachedHashJoin = Some( - createCachedBuildSideReuseHandle(buildKeys => new CudfHashJoin(buildKeys, compareNullsEqual))) + cachedBuildSide.flatMap { cached => + cached match { + case hashJoin: CachedHashJoin => + withResource(hashJoin.handle.acquire()) { lease => + Some(f(lease.resource)) + } + case _ => None + } } - cachedHashJoin } } - protected def cachedDistinctHashJoinFor(expectedBuildSide: GpuBuildSide): Option[DistinctHashJoin] = { - if (buildSideReuseDisabled || !buildStats.isDistinct || expectedBuildSide != buildSide) { + protected def withCachedDistinctHashJoin[T]( + expectedBuildSide: GpuBuildSide)(f: DistinctHashJoin => T): Option[T] = { + if 
(!canUseCachedDistinctHashJoin(expectedBuildSide)) { None } else { - if (cachedDistinctHashJoin.isEmpty) { - cachedDistinctHashJoin = Some( - createCachedBuildSideReuseHandle(buildKeys => - new DistinctHashJoin(buildKeys, compareNullsEqual))) + cachedBuildSide.flatMap { cached => + cached match { + case distinctHashJoin: CachedDistinctHashJoin => + withResource(distinctHashJoin.handle.acquire()) { lease => + Some(f(lease.resource)) + } + case _ => None + } } - cachedDistinctHashJoin } } - protected def disableBuildSideReuse(): Unit = { - buildSideReuseDisabled = true - val hashJoinToClose = cachedHashJoin - cachedHashJoin = None - val distinctHashJoinToClose = cachedDistinctHashJoin - cachedDistinctHashJoin = None - Seq(hashJoinToClose, distinctHashJoinToClose).flatten.safeClose() + protected def disableCachedBuildSide(): Unit = { + cachedBuildSideDisabled = true } /** @@ -1427,7 +1432,7 @@ abstract class BaseHashJoinIterator( || joinType == LeftOuter || joinType == RightOuter || joinType == FullOuter => - disableBuildSideReuse() + disableCachedBuildSide() // Because this is just an estimate, it is possible for us to get this wrong, so // make sure we at least split the batch in half. val numBatches = Math.max(2, estimatedNumBatches(spillOnlyCb)) @@ -1496,7 +1501,7 @@ abstract class BaseHashJoinIterator( override def close(): Unit = { if (!closed) { - disableBuildSideReuse() + disableCachedBuildSide() super.close() } } @@ -1531,11 +1536,13 @@ class HashJoinIterator( conditionForLogging: Option[Expression], opTime: GpuMetric, private val joinTime: GpuMetric, - enableBuildSideReuse: Boolean = false) + enableBuildSideReuse: Boolean = false, + cachedBuildSide: Option[CachedBuildSide] = None) extends BaseHashJoinIterator( built, boundBuiltKeys, buildStatsOpt, + cachedBuildSide, stream, boundStreamKeys, streamAttributes, @@ -1566,7 +1573,7 @@ class HashJoinIterator( } } catch { case _: OutOfMemoryError | _: GpuOOM => - disableBuildSideReuse() + disableCachedBuildSide() fallback } finally { cb.allowSpilling() @@ -1610,18 +1617,22 @@ class HashJoinIterator( case _: InnerLike => buildSide match { case GpuBuildRight => - cachedHashJoinFor(GpuBuildRight) - .map(hashJoin => JoinImpl.innerHashJoinBuildRightRowCount(streamKeys, hashJoin)) + withCachedHashJoin(GpuBuildRight) { hashJoin => + JoinImpl.innerHashJoinBuildRightRowCount(streamKeys, hashJoin) + } case GpuBuildLeft => - cachedHashJoinFor(GpuBuildLeft) - .map(hashJoin => JoinImpl.innerHashJoinBuildLeftRowCount(streamKeys, hashJoin)) + withCachedHashJoin(GpuBuildLeft) { hashJoin => + JoinImpl.innerHashJoinBuildLeftRowCount(streamKeys, hashJoin) + } } case LeftOuter => - cachedHashJoinFor(GpuBuildRight) - .map(hashJoin => JoinImpl.leftOuterHashJoinBuildRightRowCount(streamKeys, hashJoin)) + withCachedHashJoin(GpuBuildRight) { hashJoin => + JoinImpl.leftOuterHashJoinBuildRightRowCount(streamKeys, hashJoin) + } case RightOuter => - cachedHashJoinFor(GpuBuildLeft) - .map(hashJoin => JoinImpl.rightOuterHashJoinBuildLeftRowCount(streamKeys, hashJoin)) + withCachedHashJoin(GpuBuildLeft) { hashJoin => + JoinImpl.rightOuterHashJoinBuildLeftRowCount(streamKeys, hashJoin) + } case _ => None } @@ -1633,17 +1644,19 @@ class HashJoinIterator( outputRowCount: Option[Long]): Option[GatherMapsResult] = { buildSide match { case GpuBuildRight => - cachedHashJoinFor(GpuBuildRight) - .map(hashJoin => JoinImpl.innerHashJoinBuildRight(leftKeys, hashJoin, outputRowCount)) + withCachedHashJoin(GpuBuildRight) { hashJoin => + JoinImpl.innerHashJoinBuildRight(leftKeys, 
hashJoin, outputRowCount) + } case GpuBuildLeft => - cachedHashJoinFor(GpuBuildLeft) - .map(hashJoin => JoinImpl.innerHashJoinBuildLeft(rightKeys, hashJoin, outputRowCount)) + withCachedHashJoin(GpuBuildLeft) { hashJoin => + JoinImpl.innerHashJoinBuildLeft(rightKeys, hashJoin, outputRowCount) + } } } private def reusedGenericLeftSemi( leftKeys: Table): Option[GatherMapsResult] = { - cachedHashJoinFor(GpuBuildRight).map { hashJoin => + withCachedHashJoin(GpuBuildRight) { hashJoin => withResource(JoinImpl.innerHashJoinBuildRight(leftKeys, hashJoin)) { innerMaps => JoinImpl.makeLeftSemi(innerMaps, leftKeys.getRowCount.toInt) } @@ -1652,7 +1665,7 @@ class HashJoinIterator( private def reusedGenericLeftAnti( leftKeys: Table): Option[GatherMapsResult] = { - cachedHashJoinFor(GpuBuildRight).map { hashJoin => + withCachedHashJoin(GpuBuildRight) { hashJoin => withResource(JoinImpl.innerHashJoinBuildRight(leftKeys, hashJoin)) { innerMaps => JoinImpl.makeLeftAnti(innerMaps, leftKeys.getRowCount.toInt) } @@ -1662,7 +1675,7 @@ class HashJoinIterator( private def computeDistinctJoin( leftKeys: Table, rightKeys: Table): GatherMapsResult = { - val reused = cachedDistinctHashJoinFor(buildSide).map { distinctHashJoin => + val reused = withCachedDistinctHashJoin(buildSide) { distinctHashJoin => logJoinCardinality(leftKeys, rightKeys, "distinct (reused)") joinType match { case LeftOuter => @@ -1800,12 +1813,14 @@ class HashJoinIterator( val result = joinType match { case LeftOuter => - cachedHashJoinFor(GpuBuildRight) - .map(hashJoin => JoinImpl.leftOuterHashJoinBuildRight(leftKeys, hashJoin, numJoinRows)) + withCachedHashJoin(GpuBuildRight) { hashJoin => + JoinImpl.leftOuterHashJoinBuildRight(leftKeys, hashJoin, numJoinRows) + } .getOrElse(JoinImpl.leftOuterHashJoinBuildRight(leftKeys, rightKeys, compareNullsEqual)) case RightOuter => - cachedHashJoinFor(GpuBuildLeft) - .map(hashJoin => JoinImpl.rightOuterHashJoinBuildLeft(rightKeys, hashJoin, numJoinRows)) + withCachedHashJoin(GpuBuildLeft) { hashJoin => + JoinImpl.rightOuterHashJoinBuildLeft(rightKeys, hashJoin, numJoinRows) + } .getOrElse(JoinImpl.rightOuterHashJoinBuildLeft(leftKeys, rightKeys, compareNullsEqual)) case _: InnerLike => cachedGenericInnerJoin(leftKeys, rightKeys, numJoinRows).getOrElse { @@ -1846,11 +1861,13 @@ class ConditionalHashJoinIterator( conditionForLogging: Option[Expression], opTime: GpuMetric, joinTime: GpuMetric, - enableBuildSideReuse: Boolean = false) + enableBuildSideReuse: Boolean = false, + cachedBuildSide: Option[CachedBuildSide] = None) extends BaseHashJoinIterator( built, boundBuiltKeys, buildStatsOpt, + cachedBuildSide, stream, boundStreamKeys, streamAttributes, @@ -2066,11 +2083,13 @@ class HashJoinStreamSideIterator( conditionForLogging: Option[Expression], opTime: GpuMetric, joinTime: GpuMetric, - enableBuildSideReuse: Boolean = false) + enableBuildSideReuse: Boolean = false, + cachedBuildSide: Option[CachedBuildSide] = None) extends BaseHashJoinIterator( built, boundBuiltKeys, buildStatsOpt, + cachedBuildSide, stream, boundStreamKeys, streamAttributes, @@ -2118,11 +2137,13 @@ class HashJoinStreamSideIterator( rightKeys: Table): Option[GatherMapsResult] = { cudfBuildSide match { case GpuBuildRight => - cachedHashJoinFor(GpuBuildRight) - .map(hashJoin => JoinImpl.innerHashJoinBuildRight(leftKeys, hashJoin)) + withCachedHashJoin(GpuBuildRight) { hashJoin => + JoinImpl.innerHashJoinBuildRight(leftKeys, hashJoin) + } case GpuBuildLeft => - cachedHashJoinFor(GpuBuildLeft) - .map(hashJoin => 
JoinImpl.innerHashJoinBuildLeft(rightKeys, hashJoin)) + withCachedHashJoin(GpuBuildLeft) { hashJoin => + JoinImpl.innerHashJoinBuildLeft(rightKeys, hashJoin) + } } } @@ -2130,12 +2151,14 @@ class HashJoinStreamSideIterator( leftKeys: Table, rightKeys: Table, originalJoinType: Option[JoinType]): GatherMapsResult = { - val implName = cachedDistinctHashJoinFor(cudfBuildSide) - .map(_ => s"distinct (outer: $joinType, reused)") - .getOrElse(s"distinct (outer: $joinType)") + val implName = if (canUseCachedDistinctHashJoin(cudfBuildSide)) { + s"distinct (outer: $joinType, reused)" + } else { + s"distinct (outer: $joinType)" + } logJoinCardinality(leftKeys, rightKeys, implName, originalJoinType) - val result = cachedDistinctHashJoinFor(cudfBuildSide).map { distinctHashJoin => + val result = withCachedDistinctHashJoin(cudfBuildSide) { distinctHashJoin => subJoinType match { case LeftOuter => JoinImpl.leftOuterDistinctHashJoinBuildRight(leftKeys, distinctHashJoin) @@ -2271,12 +2294,14 @@ class HashJoinStreamSideIterator( val result = subJoinType match { case LeftOuter => - cachedHashJoinFor(GpuBuildRight) - .map(hashJoin => JoinImpl.leftOuterHashJoinBuildRight(leftKeys, hashJoin)) + withCachedHashJoin(GpuBuildRight) { hashJoin => + JoinImpl.leftOuterHashJoinBuildRight(leftKeys, hashJoin) + } .getOrElse(JoinImpl.leftOuterHashJoinBuildRight(leftKeys, rightKeys, compareNullsEqual)) case RightOuter => - cachedHashJoinFor(GpuBuildLeft) - .map(hashJoin => JoinImpl.rightOuterHashJoinBuildLeft(rightKeys, hashJoin)) + withCachedHashJoin(GpuBuildLeft) { hashJoin => + JoinImpl.rightOuterHashJoinBuildLeft(rightKeys, hashJoin) + } .getOrElse(JoinImpl.rightOuterHashJoinBuildLeft(leftKeys, rightKeys, compareNullsEqual)) case Inner => cachedUnconditionalInnerHashJoin(leftKeys, rightKeys).getOrElse { @@ -2603,13 +2628,15 @@ class HashOuterJoinIterator( conditionForLogging: Option[Expression], opTime: GpuMetric, joinTime: GpuMetric, - enableBuildSideReuse: Boolean = false) extends Iterator[ColumnarBatch] with TaskAutoCloseableResource { + enableBuildSideReuse: Boolean = false, + cachedBuildSide: Option[CachedBuildSide] = None) + extends Iterator[ColumnarBatch] with TaskAutoCloseableResource { private val streamJoinIter = new HashJoinStreamSideIterator(joinType, built, boundBuiltKeys, - buildStats, buildSideTrackerInit, stream, boundStreamKeys, streamAttributes, + buildStats, buildSideTrackerInit, stream, boundStreamKeys, streamAttributes, lazyCompiledCondition, joinOptions, buildSide, compareNullsEqual, conditionForLogging, opTime, joinTime, - enableBuildSideReuse) + enableBuildSideReuse, cachedBuildSide) private var finalBatch: Option[ColumnarBatch] = None @@ -2914,10 +2941,23 @@ trait GpuHashJoin extends GpuJoinExec { numOutputBatches: GpuMetric, opTime: GpuMetric, joinTime: GpuMetric, - enableBuildSideReuse: Boolean): Iterator[ColumnarBatch] = { + enableBuildSideReuse: Boolean, + broadcastBatch: Option[SerializeConcatHostBuffersDeserializeBatch] = None, + buildSideCacheBuilds: GpuMetric = NoopMetric, + buildSideCacheHits: GpuMetric = NoopMetric): Iterator[ColumnarBatch] = { val filterOutNull = GpuHashJoin.buildSideNeedsNullFilter(joinType, compareNullsEqual, buildSide, buildKeys) + val cachedBuildSide = if (enableBuildSideReuse) { + broadcastBatch.map(_.getCachedBuildSide( + boundBuildKeys, + compareNullsEqual, + filterOutNull, + buildSideCacheBuilds, + buildSideCacheHits)) + } else { + None + } val nullFiltered = if (filterOutNull) { val sb = closeOnExcept(builtBatch)( @@ -2960,10 +3000,13 @@ trait GpuHashJoin 
extends GpuJoinExec { val lazyCond = boundConditionLeftRight.map { cond => LazyCompiledCondition(cond, left.output.size, right.output.size) } - new HashOuterJoinIterator(joinType, spillableBuiltBatch, boundBuildKeys, None, None, + new HashOuterJoinIterator(joinType, spillableBuiltBatch, boundBuildKeys, None, + None, lazyStream, boundStreamKeys, streamedPlan.output, lazyCond, joinOptions, buildSide, - compareNullsEqual, condition, opTime, joinTime, enableBuildSideReuse) + compareNullsEqual, condition, opTime, joinTime, + enableBuildSideReuse = enableBuildSideReuse, + cachedBuildSide = cachedBuildSide) case _ => if (boundConditionLeftRight.isDefined) { // ConditionalHashJoinIterator will close the LazyCompiledCondition @@ -2975,12 +3018,15 @@ trait GpuHashJoin extends GpuJoinExec { new ConditionalHashJoinIterator(spillableBuiltBatch, boundBuildKeys, None, lazyStream, boundStreamKeys, streamedPlan.output, lazyCond, joinOptions, joinType, buildSide, - compareNullsEqual, condition, opTime, joinTime, enableBuildSideReuse) + compareNullsEqual, condition, opTime, joinTime, + enableBuildSideReuse = enableBuildSideReuse, + cachedBuildSide = cachedBuildSide) } else { new HashJoinIterator(spillableBuiltBatch, boundBuildKeys, None, lazyStream, boundStreamKeys, streamedPlan.output, joinOptions, joinType, buildSide, compareNullsEqual, condition, opTime, joinTime, - enableBuildSideReuse) + enableBuildSideReuse = enableBuildSideReuse, + cachedBuildSide = cachedBuildSide) } } diff --git a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala index 52a948c6f0b..6ed9967cdb7 100644 --- a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala +++ b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala @@ -96,6 +96,8 @@ case class GpuBroadcastHashJoinExec( NUM_INPUT_ROWS -> createMetric(DEBUG_LEVEL, DESCRIPTION_NUM_INPUT_ROWS), NUM_INPUT_BATCHES -> createMetric(DEBUG_LEVEL, DESCRIPTION_NUM_INPUT_BATCHES), CONCAT_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_CONCAT_TIME), + BUILD_SIDE_CACHE_BUILDS -> createMetric(MODERATE_LEVEL, DESCRIPTION_BUILD_SIDE_CACHE_BUILDS), + BUILD_SIDE_CACHE_HITS -> createMetric(MODERATE_LEVEL, DESCRIPTION_BUILD_SIDE_CACHE_HITS), ) override def requiredChildDistribution: Seq[Distribution] = { diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala index 93b23df34b5..20e6c54f7f0 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala @@ -20,6 +20,7 @@ import com.nvidia.spark.rapids.TestUtils.findOperator import org.apache.spark.SparkConf import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.functions.broadcast import org.apache.spark.sql.rapids.execution.{GpuBroadcastHashJoinExec, GpuHashJoin} @@ -148,6 +149,14 @@ class BroadcastHashJoinSuite extends SparkQueryCompareTestSuite { (build, probe) => broadcast(build).join(probe, Seq("join_key"), "right") } + IGNORE_ORDER_testSparkResultsAreEqual2( + "broadcast hash join reuse distinct full outer build right", + nullableProbeDf, + nullableDistinctBuildDf, + conf = 
broadcastReuseConf) { + (probe, build) => probe.join(broadcast(build), Seq("join_key"), "fullouter") + } + IGNORE_ORDER_testSparkResultsAreEqual2( "broadcast hash join reuse non-distinct inner build right", streamedProbeDf, @@ -214,4 +223,24 @@ class BroadcastHashJoinSuite extends SparkQueryCompareTestSuite { probe("join_key") === build("join_key") && probe("probe_value") <= build("build_value"), "inner") } + + test("broadcast hash join reuse same broadcast in multiple joins plan") { + val conf = broadcastReuseConf.clone().set("spark.sql.exchange.reuse", "true") + withGpuSparkSession(spark => { + val probe = streamedProbeDf(spark) + val build = broadcast(distinctBuildDf(spark)) + val joined = probe + .join(build, Seq("join_key"), "inner") + .select("join_key", "probe_value") + .join(build, Seq("join_key"), "inner") + .select("join_key", "probe_value") + + joined.collect() + val plan = joined.queryExecution.executedPlan + val bhjCount = PlanUtils.findOperators(plan, _.isInstanceOf[GpuBroadcastHashJoinExec]) + val reusedExchanges = PlanUtils.findOperators(plan, _.isInstanceOf[ReusedExchangeExec]) + assertResult(2)(bhjCount.size) + assert(reusedExchanges.nonEmpty) + }, conf) + } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/spill/SharedRecomputableDeviceHandleSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/spill/SharedRecomputableDeviceHandleSuite.scala new file mode 100644 index 00000000000..3ac82a92c19 --- /dev/null +++ b/tests/src/test/scala/com/nvidia/spark/rapids/spill/SharedRecomputableDeviceHandleSuite.scala @@ -0,0 +1,161 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.spill + +import java.util.concurrent.{Callable, CountDownLatch, Executors, TimeUnit} + +import scala.collection.mutable.ArrayBuffer + +import com.nvidia.spark.rapids.Arm.withResource + +class SharedRecomputableDeviceHandleSuite extends SpillUnitTestBase { + private class TestResource(val id: Int, closedIds: ArrayBuffer[Int]) extends AutoCloseable { + private var closed = false + + def isClosed: Boolean = closed + + override def close(): Unit = { + if (closed) { + throw new IllegalStateException(s"resource $id closed twice") + } + closed = true + closedIds += id + } + } + + test("recomputable handle uses pin count for spillability and rebuilds after spill") { + val closedIds = ArrayBuffer[Int]() + var buildCount = 0 + + def buildResource(): TestResource = { + buildCount += 1 + new TestResource(buildCount, closedIds) + } + + withResource( + new SharedRecomputableDeviceHandle(1024L, buildResource(), () => buildResource())) { + handle => + SpillFramework.stores.deviceStore.track(handle) + assertResult(1)(SpillFramework.stores.deviceStore.numHandles) + assert(handle.spillable) + + withResource(handle.acquire()) { lease => + assertResult(1)(lease.resource.id) + assert(!handle.spillable) + assertResult(0L)(SpillFramework.stores.deviceStore.spill(handle.approxSizeInBytes)) + } + + assert(handle.spillable) + assertResult(handle.approxSizeInBytes)( + SpillFramework.stores.deviceStore.spill(handle.approxSizeInBytes)) + assertResult(Seq(1))(closedIds.toSeq) + assertResult(0)(SpillFramework.stores.deviceStore.numHandles) + + withResource(handle.acquire()) { lease => + assertResult(2)(lease.resource.id) + assert(!lease.resource.isClosed) + } + + assertResult(2)(buildCount) + assertResult(1)(SpillFramework.stores.deviceStore.numHandles) + } + + assertResult(Seq(1, 2))(closedIds.sorted.toSeq) + } + + test("releaseSpilled only closes evicted generations") { + val closedIds = ArrayBuffer[Int]() + var buildCount = 0 + + def buildResource(): TestResource = { + buildCount += 1 + new TestResource(buildCount, closedIds) + } + + withResource( + new SharedRecomputableDeviceHandle(1024L, buildResource(), () => buildResource())) { + handle => + SpillFramework.stores.deviceStore.track(handle) + assertResult(handle.approxSizeInBytes)(handle.spill()) + + withResource(handle.acquire()) { lease => + assertResult(2)(lease.resource.id) + assert(!lease.resource.isClosed) + } + + assertResult(handle.approxSizeInBytes)(handle.spill()) + handle.releaseSpilled() + assertResult(Seq(1, 2))(closedIds.toSeq.sorted) + + withResource(handle.acquire()) { lease => + assertResult(3)(lease.resource.id) + assert(!lease.resource.isClosed) + } + } + + assertResult(Seq(1, 2, 3))(closedIds.sorted.toSeq) + } + + test("concurrent acquires only rebuild once after eviction") { + val closedIds = ArrayBuffer[Int]() + var buildCount = 0 + val rebuildStarted = new CountDownLatch(1) + val allowRebuild = new CountDownLatch(1) + + def buildResource(): TestResource = synchronized { + buildCount += 1 + val id = buildCount + if (id > 1) { + rebuildStarted.countDown() + assert(allowRebuild.await(30, TimeUnit.SECONDS)) + } + new TestResource(id, closedIds) + } + + withResource( + new SharedRecomputableDeviceHandle(1024L, buildResource(), () => buildResource())) { + handle => + SpillFramework.stores.deviceStore.track(handle) + assertResult(handle.approxSizeInBytes)(handle.spill()) + + val pool = Executors.newFixedThreadPool(2) + try { + val futures = (0 until 2).map { _ => + pool.submit(new Callable[Int] { + override 
def call(): Int = {
+              withResource(handle.acquire()) { lease =>
+                lease.resource.id
+              }
+            }
+          })
+        }
+
+        assert(rebuildStarted.await(30, TimeUnit.SECONDS))
+        allowRebuild.countDown()
+
+        val ids = futures.map(_.get(30, TimeUnit.SECONDS)).sorted
+        assertResult(Seq(2, 2))(ids)
+        assertResult(2)(buildCount)
+      } finally {
+        pool.shutdownNow()
+        assert(pool.awaitTermination(30, TimeUnit.SECONDS))
+      }
+    }
+
+    assertResult(Seq(1, 2))(closedIds.sorted.toSeq)
+  }
+}

From ff0e6f28d009cf9cd45ad3b426ea563af1fb637a Mon Sep 17 00:00:00 2001
From: Rishi Chandra
Date: Thu, 23 Apr 2026 14:06:25 -0700
Subject: [PATCH 06/12] Lots of cleanups

---
 integration_tests/src/main/python/join_test.py          | 10 ++--------
 .../nvidia/spark/rapids/GpuShuffledHashJoinExec.scala   |  1 +
 .../scala/com/nvidia/spark/rapids/RapidsConf.scala      |  9 ++++++---
 .../com/nvidia/spark/rapids/spill/SpillFramework.scala  |  2 +-
 .../execution/GpuBroadcastHashJoinExecBase.scala        |  6 ++----
 .../rapids/execution/GpuBroadcastHashJoinExec.scala     |  1 +
 6 files changed, 13 insertions(+), 16 deletions(-)

diff --git a/integration_tests/src/main/python/join_test.py b/integration_tests/src/main/python/join_test.py
index a1aa7f93236..3c81454e9cb 100644
--- a/integration_tests/src/main/python/join_test.py
+++ b/integration_tests/src/main/python/join_test.py
@@ -251,10 +251,7 @@ def do_join(spark):
                                        'GpuBroadcastHashJoinExec', 'buildSideCacheHits')
     assert cache_builds == 1
 
-    if gpu_df.sparkSession.sparkContext.master.startswith("local"):
-        assert cache_hits >= 3
-    else:
-        assert cache_hits > 0
+    assert cache_hits > 0
 
 
 @ignore_order(local=True)
@@ -293,10 +290,7 @@ def do_join(spark):
                                        'GpuBroadcastHashJoinExec', 'buildSideCacheHits')
     assert cache_builds == 1
 
-    if gpu_df.sparkSession.sparkContext.master.startswith("local"):
-        assert cache_hits >= 3
-    else:
-        assert cache_hits > 0
+    assert cache_hits > 0
 
 
 # local sort because of https://github.com/NVIDIA/spark-rapids/issues/84
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledHashJoinExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledHashJoinExec.scala
index 76e0fe3842b..f0b5755a59a 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledHashJoinExec.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledHashJoinExec.scala
@@ -542,3 +542,4 @@ object GpuShuffledHashJoinExec extends Logging {
     retIter
   }
 }
+
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
index cc7b8240a40..279c699b080 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
@@ -784,9 +784,12 @@ val GPU_COREDUMP_PIPE_PATTERN = conf("spark.rapids.gpu.coreDump.pipePattern")
 
   val BROADCAST_HASH_TABLE_REUSE =
     conf("spark.rapids.sql.join.broadcastHashTable.reuse")
-      .doc("Enable reuse of broadcast-side hash table state across stream batches for " +
-        "broadcast hash joins. This only applies when the broadcast side remains the " +
-        "physical build side selected by the join implementation.")
+      .doc("Enable reuse of the broadcast-side hash table for broadcast hash joins. " +
+        "When enabled, the hash table is built once per broadcast and shared across all " +
+        "stream batches within a task and across all tasks that consume the same broadcast " +
+        "on an executor. Reuse pins the physical build side to the broadcast side for the " +
+        "lifetime of each cached join, overriding the dynamic build-side selection " +
+        s"heuristic configured by ${JOIN_BUILD_SIDE.key}.")
       .internal()
       .booleanConf
       .createWithDefault(false)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala
index 79ae4bad0d9..3dd0122edce 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala
@@ -352,7 +352,7 @@ object SharedRecomputableDeviceHandle {
 }
 
 /**
- * Spill-framework handle for device-only state that is cheaper to recompute than to spill.
+ * Handle for device-only state that is cheaper to recompute than to spill.
  *
  * When this handle is selected for spilling, it does not copy anything to host or disk. Instead
  * it marks the current device state as evicted and returns `approxSizeInBytes` so the spill
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExecBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExecBase.scala
index 4b113fe7168..44103d43168 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExecBase.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExecBase.scala
@@ -120,10 +120,8 @@ abstract class GpuBroadcastHashJoinExecBase(
     OP_TIME_LEGACY -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_OP_TIME_LEGACY),
     STREAM_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_STREAM_TIME),
     JOIN_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_JOIN_TIME),
-    BUILD_SIDE_CACHE_BUILDS ->
-      createMetric(MODERATE_LEVEL, DESCRIPTION_BUILD_SIDE_CACHE_BUILDS),
-    BUILD_SIDE_CACHE_HITS ->
-      createMetric(MODERATE_LEVEL, DESCRIPTION_BUILD_SIDE_CACHE_HITS))
+    BUILD_SIDE_CACHE_BUILDS -> createMetric(MODERATE_LEVEL, DESCRIPTION_BUILD_SIDE_CACHE_BUILDS),
+    BUILD_SIDE_CACHE_HITS -> createMetric(MODERATE_LEVEL, DESCRIPTION_BUILD_SIDE_CACHE_HITS))
 
   override def requiredChildDistribution: Seq[Distribution] = {
     val mode = HashedRelationBroadcastMode(buildKeys)
diff --git a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala
index 6ed9967cdb7..1c262ea9960 100644
--- a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala
+++ b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala
@@ -223,3 +223,4 @@ case class GpuBroadcastHashJoinExec(
     }
   }
 }
+

From a65694fa24b0ccef8da540a489abd28ef6eb1f3f Mon Sep 17 00:00:00 2001
From: Rishi Chandra
Date: Thu, 23 Apr 2026 19:47:39 -0700
Subject: [PATCH 07/12] Make conf public

---
 .../src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala | 1 -
 1 file changed, 1 deletion(-)

diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
index 279c699b080..04874e34464 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
@@ -790,7 +790,6 @@ val GPU_COREDUMP_PIPE_PATTERN = conf("spark.rapids.gpu.coreDump.pipePattern")
         "on an executor. Reuse pins the physical build side to the broadcast side for the " +
         "lifetime of each cached join, overriding the dynamic build-side selection " +
         s"heuristic configured by ${JOIN_BUILD_SIDE.key}.")
-      .internal()
       .booleanConf
       .createWithDefault(false)
 

From f90e8d0b168dedf2729a821802053758f8c3b499 Mon Sep 17 00:00:00 2001
From: Rishi Chandra
Date: Thu, 23 Apr 2026 20:07:14 -0700
Subject: [PATCH 08/12] Use cudf distinct hash join wrapper

---
 .../execution/BroadcastCachedBuildSide.scala | 11 +++++------
 .../sql/rapids/execution/GpuHashJoin.scala   | 22 +++++++++++-----------
 2 files changed, 16 insertions(+), 17 deletions(-)

diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/BroadcastCachedBuildSide.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/BroadcastCachedBuildSide.scala
index cfcaadc4255..7bd1132bde1 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/BroadcastCachedBuildSide.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/BroadcastCachedBuildSide.scala
@@ -16,11 +16,10 @@
 
 package org.apache.spark.sql.rapids.execution
 
-import ai.rapids.cudf.{HashJoin => CudfHashJoin, Table}
+import ai.rapids.cudf.{DistinctHashJoin => CudfDistinctHashJoin, HashJoin => CudfHashJoin, Table}
 import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuProjectExec, SpillableColumnarBatch}
 import com.nvidia.spark.rapids.Arm.withResource
 import com.nvidia.spark.rapids.RmmRapidsRetryIterator.withRetryNoSplit
-import com.nvidia.spark.rapids.jni.DistinctHashJoin
 import com.nvidia.spark.rapids.spill.SharedRecomputableDeviceHandle
 
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -37,7 +36,7 @@ final class CachedHashJoin(
 
 final class CachedDistinctHashJoin(
     override val buildStats: JoinBuildSideStats,
-    val handle: SharedRecomputableDeviceHandle[DistinctHashJoin]) extends CachedBuildSide {
+    val handle: SharedRecomputableDeviceHandle[CudfDistinctHashJoin]) extends CachedBuildSide {
   override def close(): Unit = handle.close()
 }
 
@@ -97,9 +96,9 @@ object BroadcastCachedBuildSide {
       }
     }
 
-    def buildDistinctHashJoin(): DistinctHashJoin = {
+    def buildDistinctHashJoin(): CudfDistinctHashJoin = {
       withBuildKeys(broadcastBatch, boundBuiltKeys, filterOutNulls) { buildKeys =>
-        new DistinctHashJoin(buildKeys, compareNullsEqual)
+        new CudfDistinctHashJoin(buildKeys, compareNullsEqual)
       }
     }
 
@@ -111,7 +110,7 @@ object BroadcastCachedBuildSide {
         stats,
         SharedRecomputableDeviceHandle(
           approxSizeInBytes,
-          new DistinctHashJoin(buildKeys, compareNullsEqual)) {
+          new CudfDistinctHashJoin(buildKeys, compareNullsEqual)) {
           buildDistinctHashJoin()
         })
     } else {
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala
index 8f5cb83c58f..5a23f97b0b1 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala
@@ -15,13 +15,13 @@
  */
 package org.apache.spark.sql.rapids.execution
 
-import ai.rapids.cudf.{ColumnView, DType, GatherMap, HashJoin => CudfHashJoin, NullEquality, OutOfBoundsPolicy, Scalar, Table}
+import ai.rapids.cudf.{ColumnView, DistinctHashJoin => CudfDistinctHashJoin, DType, GatherMap, HashJoin => CudfHashJoin, NullEquality, OutOfBoundsPolicy, Scalar, Table}
 import ai.rapids.cudf.ast.CompiledExpression
 import com.nvidia.spark.rapids._
 import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
 import com.nvidia.spark.rapids.RapidsPluginImplicits._
 import com.nvidia.spark.rapids.RmmRapidsRetryIterator.{withRestoreOnRetry, withRetryNoSplit}
-import com.nvidia.spark.rapids.jni.{DistinctHashJoin, GpuOOM, JoinPrimitives}
+import com.nvidia.spark.rapids.jni.{GpuOOM, JoinPrimitives}
 import com.nvidia.spark.rapids.shims.ShimBinaryExecNode
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, NamedExpression}
@@ -337,29 +337,29 @@ object JoinImpl {
 
   def innerDistinctHashJoinBuildLeft(
       rightKeys: Table,
-      leftHashJoin: DistinctHashJoin): GatherMapsResult = {
-    val arrayRet = leftHashJoin.innerJoinGatherMaps(rightKeys)
+      leftHashJoin: CudfDistinctHashJoin): GatherMapsResult = {
+    val arrayRet = rightKeys.innerJoinGatherMaps(leftHashJoin)
     GatherMapsResult(arrayRet(1), arrayRet(0))
   }
 
   def innerDistinctHashJoinBuildRight(
       leftKeys: Table,
-      rightHashJoin: DistinctHashJoin): GatherMapsResult = {
-    val arrayRet = rightHashJoin.innerJoinGatherMaps(leftKeys)
+      rightHashJoin: CudfDistinctHashJoin): GatherMapsResult = {
+    val arrayRet = leftKeys.innerJoinGatherMaps(rightHashJoin)
     GatherMapsResult(arrayRet(0), arrayRet(1))
   }
 
   def leftOuterDistinctHashJoinBuildRight(
       leftKeys: Table,
-      rightHashJoin: DistinctHashJoin): GatherMapsResult = {
-    val rightRet = rightHashJoin.leftJoinGatherMap(leftKeys)
+      rightHashJoin: CudfDistinctHashJoin): GatherMapsResult = {
+    val rightRet = leftKeys.leftDistinctJoinGatherMap(rightHashJoin)
     GatherMapsResult.makeFromRight(rightRet)
   }
 
   def rightOuterDistinctHashJoinBuildLeft(
       rightKeys: Table,
-      leftHashJoin: DistinctHashJoin): GatherMapsResult = {
-    val leftRet = leftHashJoin.leftJoinGatherMap(rightKeys)
+      leftHashJoin: CudfDistinctHashJoin): GatherMapsResult = {
+    val leftRet = rightKeys.leftDistinctJoinGatherMap(leftHashJoin)
     GatherMapsResult.makeFromLeft(leftRet)
   }
 
@@ -1192,7 +1192,7 @@ abstract class BaseHashJoinIterator(
   }
 
   protected def withCachedDistinctHashJoin[T](
-      expectedBuildSide: GpuBuildSide)(f: DistinctHashJoin => T): Option[T] = {
+      expectedBuildSide: GpuBuildSide)(f: CudfDistinctHashJoin => T): Option[T] = {
     if (!canUseCachedDistinctHashJoin(expectedBuildSide)) {
       None
     } else {

From 5443f394f4d8ccff00b7f1b8b1f65076af604a91 Mon Sep 17 00:00:00 2001
From: Rishi Chandra
Date: Thu, 23 Apr 2026 21:37:04 -0700
Subject: [PATCH 09/12] Add NVTX ranges

---
 .../spark/rapids/NvtxRangeWithDoc.scala      |  6 +++++-
 .../execution/BroadcastCachedBuildSide.scala | 24 +++++++++++++++++++-----
 2 files changed, 24 insertions(+), 6 deletions(-)

diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvtxRangeWithDoc.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvtxRangeWithDoc.scala
index b700e917cca..9f598b344df 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvtxRangeWithDoc.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvtxRangeWithDoc.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2025, NVIDIA CORPORATION.
+ * Copyright (c) 2025-2026, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -393,6 +393,9 @@ object NvtxRegistry {
   val BUILD_JOIN_TABLE: NvtxId = NvtxId("build join table", NvtxColor.GREEN,
     "Building hash table for join operation")
 
+  val BROADCAST_HASH_TABLE_BUILD: NvtxId = NvtxId("broadcast hash table build",
+    NvtxColor.GREEN, "Building cuDF hash table for broadcast hash join")
+
   // Window operations
   val WINDOW: NvtxId = NvtxId("window", NvtxColor.CYAN,
     "Computing window function results")
@@ -780,6 +783,7 @@ object NvtxRegistry {
   register(EXISTENCE_JOIN_SCATTER_MAP)
   register(EXISTENCE_JOIN_BATCH)
   register(BUILD_JOIN_TABLE)
+  register(BROADCAST_HASH_TABLE_BUILD)
   register(WINDOW)
   register(RUNNING_WINDOW)
   register(DOUBLE_BATCHED_WINDOW_PRE)
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/BroadcastCachedBuildSide.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/BroadcastCachedBuildSide.scala
index 7bd1132bde1..6b7a81b12c0 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/BroadcastCachedBuildSide.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/BroadcastCachedBuildSide.scala
@@ -17,7 +17,7 @@
 package org.apache.spark.sql.rapids.execution
 
 import ai.rapids.cudf.{DistinctHashJoin => CudfDistinctHashJoin, HashJoin => CudfHashJoin, Table}
-import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuProjectExec, SpillableColumnarBatch}
+import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuProjectExec, NvtxRegistry, SpillableColumnarBatch}
 import com.nvidia.spark.rapids.Arm.withResource
 import com.nvidia.spark.rapids.RmmRapidsRetryIterator.withRetryNoSplit
 import com.nvidia.spark.rapids.spill.SharedRecomputableDeviceHandle
@@ -80,6 +80,20 @@ object BroadcastCachedBuildSide {
     }
   }
 
+  private def newHashJoin(buildKeys: Table, compareNullsEqual: Boolean): CudfHashJoin = {
+    NvtxRegistry.BROADCAST_HASH_TABLE_BUILD {
+      new CudfHashJoin(buildKeys, compareNullsEqual)
+    }
+  }
+
+  private def newDistinctHashJoin(
+      buildKeys: Table,
+      compareNullsEqual: Boolean): CudfDistinctHashJoin = {
+    NvtxRegistry.BROADCAST_HASH_TABLE_BUILD {
+      new CudfDistinctHashJoin(buildKeys, compareNullsEqual)
+    }
+  }
+
   /**
    * cuDF's reusable hash join handles are safe for concurrent probes. The executor-wide cache
    * therefore pins the live handle while a task is probing it and relies on
@@ -92,13 +106,13 @@ object BroadcastCachedBuildSide {
       filterOutNulls: Boolean): CachedBuildSide = {
     def buildHashJoin(): CudfHashJoin = {
       withBuildKeys(broadcastBatch, boundBuiltKeys, filterOutNulls) { buildKeys =>
-        new CudfHashJoin(buildKeys, compareNullsEqual)
+        newHashJoin(buildKeys, compareNullsEqual)
       }
     }
 
     def buildDistinctHashJoin(): CudfDistinctHashJoin = {
       withBuildKeys(broadcastBatch, boundBuiltKeys, filterOutNulls) { buildKeys =>
-        new CudfDistinctHashJoin(buildKeys, compareNullsEqual)
+        newDistinctHashJoin(buildKeys, compareNullsEqual)
       }
     }
 
@@ -110,7 +124,7 @@ object BroadcastCachedBuildSide {
         stats,
         SharedRecomputableDeviceHandle(
           approxSizeInBytes,
-          new CudfDistinctHashJoin(buildKeys, compareNullsEqual)) {
+          newDistinctHashJoin(buildKeys, compareNullsEqual)) {
           buildDistinctHashJoin()
         })
     } else {
@@ -118,7 +132,7 @@ object BroadcastCachedBuildSide {
         stats,
         SharedRecomputableDeviceHandle(
           approxSizeInBytes,
-          new CudfHashJoin(buildKeys, compareNullsEqual)) {
+          newHashJoin(buildKeys, compareNullsEqual)) {
           buildHashJoin()
         })
     }

From f9ffee02c84b680da47db64aa526fd0776a200ee Mon Sep 17 00:00:00 2001
From: Rishi Chandra
Date: Fri, 24 Apr 2026 12:02:02 -0700
Subject: [PATCH 10/12] Make cache metrics debug level

---
 .../sql/rapids/execution/GpuBroadcastHashJoinExecBase.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExecBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExecBase.scala
index 44103d43168..48eb0041a6e 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExecBase.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExecBase.scala
@@ -120,8 +120,8 @@ abstract class GpuBroadcastHashJoinExecBase(
     OP_TIME_LEGACY -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_OP_TIME_LEGACY),
     STREAM_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_STREAM_TIME),
     JOIN_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_JOIN_TIME),
-    BUILD_SIDE_CACHE_BUILDS -> createMetric(MODERATE_LEVEL, DESCRIPTION_BUILD_SIDE_CACHE_BUILDS),
-    BUILD_SIDE_CACHE_HITS -> createMetric(MODERATE_LEVEL, DESCRIPTION_BUILD_SIDE_CACHE_HITS))
+    BUILD_SIDE_CACHE_BUILDS -> createMetric(DEBUG_LEVEL, DESCRIPTION_BUILD_SIDE_CACHE_BUILDS),
+    BUILD_SIDE_CACHE_HITS -> createMetric(DEBUG_LEVEL, DESCRIPTION_BUILD_SIDE_CACHE_HITS))
 
   override def requiredChildDistribution: Seq[Distribution] = {
     val mode = HashedRelationBroadcastMode(buildKeys)

From 8d48e08ae13ea0244615d47cb59546fe262166bd Mon Sep 17 00:00:00 2001
From: Rishi Chandra
Date: Fri, 24 Apr 2026 13:56:38 -0700
Subject: [PATCH 11/12] Scalastyle

---
 .../spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala
index 22320d2bc2e..055120bd3ad 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala
@@ -80,7 +80,8 @@ class SerializeConcatHostBuffersDeserializeBatch(
 
   // used for memoization of deserialization to GPU on Executor
   @transient private var batchInternal: SpillableColumnarBatch = null
-  @transient private var cachedBuildSideCache: mutable.HashMap[BroadcastCachedBuildSideKey, CachedBuildSide] = null
+  @transient private var cachedBuildSideCache:
+      mutable.HashMap[BroadcastCachedBuildSideKey, CachedBuildSide] = null
 
   private def maybeGpuBatch: Option[SpillableColumnarBatch] = Option(batchInternal)
 

From 955261bee5eea86d3d17aaf85348ae7d2c85fe84 Mon Sep 17 00:00:00 2001
From: Rishi Chandra
Date: Fri, 24 Apr 2026 13:57:57 -0700
Subject: [PATCH 12/12] License headers

---
 .../apache/spark/sql/rapids/ExecutionPlanCaptureCallback.scala | 2 +-
 .../spark/sql/rapids/execution/GpuSubPartitionHashJoin.scala   | 2 +-
 tests/src/test/scala/com/nvidia/spark/rapids/JoinsSuite.scala  | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExecutionPlanCaptureCallback.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExecutionPlanCaptureCallback.scala
index 142cf9bca2e..a7f3f59d323 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExecutionPlanCaptureCallback.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExecutionPlanCaptureCallback.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2025, NVIDIA CORPORATION.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubPartitionHashJoin.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubPartitionHashJoin.scala
index 42d96f04ea9..509ff7fb385 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubPartitionHashJoin.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubPartitionHashJoin.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023-2025, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2023-2026, NVIDIA CORPORATION. All rights reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/JoinsSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/JoinsSuite.scala
index d04bb4c91b8..383cbf86b7c 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/JoinsSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/JoinsSuite.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2025, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2026, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
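
Note on the handle contract (illustrative, not part of any patch above): the concurrency test at the top of this series pins down the behavior the broadcast cache relies on, namely that acquire() pins the device-resident resource, an evicted resource is rebuilt at most once even under concurrent acquires, and only unpinned state may be evicted. Below is a self-contained toy model of that contract for reviewers. It is not the SharedRecomputableDeviceHandle implementation: every name in it is local to the example, and the real handle integrates with the spill framework rather than exposing a tryEvict method.

import java.util.concurrent.locks.ReentrantLock

// Toy model of a recompute-on-eviction handle: eviction drops the device
// state instead of copying it to host or disk, and the next acquire()
// rebuilds it from the retained source (here, just a function).
final class RecomputableHandle[T](rebuild: () => T) {
  private val lock = new ReentrantLock()
  private var resource: Option[T] = None
  private var pinned = 0
  var buildCount = 0

  // A lease pins the resource; closing it releases the pin (close once).
  final class Lease(val resource: T) extends AutoCloseable {
    override def close(): Unit = {
      lock.lock()
      try { pinned -= 1 } finally { lock.unlock() }
    }
  }

  // Pin the resource, rebuilding it first if an eviction dropped it.
  def acquire(): Lease = {
    lock.lock()
    try {
      if (resource.isEmpty) {
        buildCount += 1
        resource = Some(rebuild())
      }
      pinned += 1
      new Lease(resource.get)
    } finally { lock.unlock() }
  }

  // Stand-in for the spill framework: only unpinned state may be dropped.
  def tryEvict(): Boolean = {
    lock.lock()
    try {
      if (pinned == 0 && resource.nonEmpty) { resource = None; true }
      else false
    } finally { lock.unlock() }
  }
}

object RecomputableHandleDemo extends App {
  val handle = new RecomputableHandle[String](() => "hash table")
  val lease = handle.acquire()  // first acquire builds and pins
  assert(!handle.tryEvict())    // pinned state cannot be evicted
  lease.close()
  assert(handle.tryEvict())     // unpinned state is dropped, not spilled
  handle.acquire().close()      // the next acquire rebuilds exactly once
  assert(handle.buildCount == 2)
}

The single lock serializes the rebuild, so concurrent acquires after an eviction observe the same rebuilt resource and the build counter advances exactly once, mirroring the assertResult(2)(buildCount) check in the test above.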
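Note on trying the series out (illustrative, not part of any patch above): the user-facing switch is spark.rapids.sql.join.broadcastHashTable.reuse, which patch 07 makes public. A minimal sketch of a session that exercises the reuse path, assuming a working spark-rapids deployment; everything here other than the new key is standard plugin setup, and the table sizes are arbitrary.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.broadcast

// Standard spark-rapids setup: requires the plugin jar on the classpath
// and a GPU on the executor.
val spark = SparkSession.builder()
  .master("local[*]")
  .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
  // The flag this series adds; it defaults to false.
  .config("spark.rapids.sql.join.broadcastHashTable.reuse", "true")
  .getOrCreate()

// A broadcast hash join whose build side can be cached executor-wide: the
// hash table for `small` is built once and then shared by every stream
// batch of `large`.
val small = spark.range(1000).withColumnRenamed("id", "k")
val large = spark.range(10000000).withColumnRenamed("id", "k")
large.join(broadcast(small), "k").count()

With the flag on, buildSideCacheBuilds should stay at 1 for the join while buildSideCacheHits grows with the number of stream batches consumed, which is what the relaxed assertions in join_test.py above check.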