diff --git a/integration_tests/src/main/python/asserts.py b/integration_tests/src/main/python/asserts.py index cc9013cd845..5d3aa3f1fba 100644 --- a/integration_tests/src/main/python/asserts.py +++ b/integration_tests/src/main/python/asserts.py @@ -528,6 +528,7 @@ def assert_cpu_and_gpu_are_equal_collect_with_capture(func, _sort_locally(from_cpu, from_gpu) assert_equal(from_cpu, from_gpu) + return gpu_df def assert_cpu_and_gpu_are_equal_sql_with_capture(df_fun, sql, diff --git a/integration_tests/src/main/python/join_test.py b/integration_tests/src/main/python/join_test.py index 500d49f9b81..3c81454e9cb 100644 --- a/integration_tests/src/main/python/join_test.py +++ b/integration_tests/src/main/python/join_test.py @@ -19,7 +19,7 @@ from asserts import (assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_row_counts_equal, assert_gpu_fallback_collect, assert_cpu_and_gpu_are_equal_collect_with_capture, assert_cpu_and_gpu_are_equal_sql_with_capture, assert_gpu_and_cpu_are_equal_sql) -from conftest import is_emr_runtime +from conftest import is_emr_runtime, spark_jvm from data_gen import * from marks import ignore_order, allow_non_gpu, incompat, validate_execs_in_gpu_plan, disable_ansi_mode from spark_session import with_cpu_session, is_before_spark_330, is_databricks_runtime, is_spark_400_or_later, is_spark_411_or_later @@ -217,6 +217,82 @@ def do_join(spark): }) +@ignore_order(local=True) +@pytest.mark.parametrize("kudo_enabled", ["true", "false"], ids=idfn) +def test_broadcast_hash_join_reuse_large_build_aqe(kudo_enabled): + def do_join(spark): + left = spark.range(4096, numPartitions=4).selectExpr( + "CAST(id % 64 AS INT) AS join_key", + "CAST(id AS INT) AS probe_value") + right = spark.range(8192).selectExpr( + "CAST(id % 64 AS INT) AS join_key", + "CAST(id * 3 AS INT) AS build_value") + return left.join(broadcast(right), "join_key", "inner").groupBy("join_key").count() + + conf = { + 'spark.sql.adaptive.enabled': 'true', + 
'spark.sql.autoBroadcastJoinThreshold': '-1', + 'spark.sql.shuffle.partitions': '4', + 'spark.rapids.sql.batchSizeBytes': '1024', + 'spark.rapids.sql.join.broadcastHashTable.reuse': 'true', + kudo_enabled_conf_key: kudo_enabled + } + gpu_df = assert_cpu_and_gpu_are_equal_collect_with_capture( + do_join, + exist_classes='GpuBroadcastHashJoinExec', + conf=conf) + jvm = spark_jvm() + cache_builds = jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.sumMetric( + gpu_df._jdf, + 'GpuBroadcastHashJoinExec', + 'buildSideCacheBuilds') + cache_hits = jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.sumMetric( + gpu_df._jdf, + 'GpuBroadcastHashJoinExec', + 'buildSideCacheHits') + assert cache_builds == 1 + assert cache_hits > 0 + + +@ignore_order(local=True) +@pytest.mark.parametrize("kudo_enabled", ["true", "false"], ids=idfn) +def test_broadcast_hash_join_reuse_same_broadcast_multi_join_aqe(kudo_enabled): + def do_join(spark): + fact = spark.range(512, numPartitions=4).selectExpr( + "CAST(id % 8 AS INT) AS join_key", + "CAST(id AS INT) AS probe_value") + dim = broadcast(spark.range(32).selectExpr( + "CAST(id % 8 AS INT) AS join_key", + "CAST(id AS INT) AS build_value")) + return fact.join(dim, "join_key", "inner").select("join_key", "probe_value") \ + .join(dim, "join_key", "inner").groupBy("join_key").count() + + conf = { + 'spark.sql.adaptive.enabled': 'true', + 'spark.sql.autoBroadcastJoinThreshold': '-1', + 'spark.sql.exchange.reuse': 'true', + 'spark.sql.shuffle.partitions': '4', + 'spark.rapids.sql.batchSizeBytes': '1024', + 'spark.rapids.sql.join.broadcastHashTable.reuse': 'true', + kudo_enabled_conf_key: kudo_enabled + } + gpu_df = assert_cpu_and_gpu_are_equal_collect_with_capture( + do_join, + exist_classes='GpuBroadcastHashJoinExec,ReusedExchangeExec', + conf=conf) + jvm = spark_jvm() + cache_builds = jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.sumMetric( + gpu_df._jdf, + 'GpuBroadcastHashJoinExec', + 'buildSideCacheBuilds') 
+ cache_hits = jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.sumMetric( + gpu_df._jdf, + 'GpuBroadcastHashJoinExec', + 'buildSideCacheHits') + assert cache_builds == 1 + assert cache_hits > 0 + + # local sort because of https://github.com/NVIDIA/spark-rapids/issues/84 # After 3.1.0 is the min spark version we can drop this @ignore_order(local=True) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuMetrics.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuMetrics.scala index b32184d08ce..59d45f89dec 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuMetrics.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuMetrics.scala @@ -116,6 +116,8 @@ object GpuMetric extends Logging { val BUILD_DATA_SIZE = "buildDataSize" val BUILD_TIME = "buildTime" val STREAM_TIME = "streamTime" + val BUILD_SIDE_CACHE_BUILDS = "buildSideCacheBuilds" + val BUILD_SIDE_CACHE_HITS = "buildSideCacheHits" val NUM_TASKS_FALL_BACKED = "numTasksFallBacked" val NUM_TASKS_REPARTITIONED = "numTasksRepartitioned" val NUM_TASKS_SKIPPED_AGG = "numTasksSkippedAgg" @@ -173,6 +175,8 @@ object GpuMetric extends Logging { val DESCRIPTION_BUILD_DATA_SIZE = "build side size" val DESCRIPTION_BUILD_TIME = "build time" val DESCRIPTION_STREAM_TIME = "stream time" + val DESCRIPTION_BUILD_SIDE_CACHE_BUILDS = "cached build side builds" + val DESCRIPTION_BUILD_SIDE_CACHE_HITS = "cached build side hits" val DESCRIPTION_NUM_TASKS_FALL_BACKED = "number of sort fallback tasks" val DESCRIPTION_NUM_TASKS_REPARTITIONED = "number of tasks repartitioned for agg" val DESCRIPTION_NUM_TASKS_SKIPPED_AGG = "number of tasks skipped aggregation" diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledHashJoinExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledHashJoinExec.scala index e7bf11febf7..f0b5755a59a 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledHashJoinExec.scala +++ 
b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledHashJoinExec.scala @@ -264,7 +264,7 @@ case class GpuShuffledHashJoinExec( } // doJoin will close singleBatch doJoin(singleBatch, maybeBufferedStreamIter, joinOptions, - numOutputRows, numOutputBatches, opTime, joinTime) + numOutputRows, numOutputBatches, opTime, joinTime, enableBuildSideReuse = false) case Right(builtBatchIter) => // For big joins, when the build data can not fit into a single batch. val sizeBuildIter = builtBatchIter.map { cb => diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvtxRangeWithDoc.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvtxRangeWithDoc.scala index b700e917cca..9f598b344df 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvtxRangeWithDoc.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvtxRangeWithDoc.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2025, NVIDIA CORPORATION. + * Copyright (c) 2025-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -393,6 +393,9 @@ object NvtxRegistry { val BUILD_JOIN_TABLE: NvtxId = NvtxId("build join table", NvtxColor.GREEN, "Building hash table for join operation") + val BROADCAST_HASH_TABLE_BUILD: NvtxId = NvtxId("broadcast hash table build", + NvtxColor.GREEN, "Building cuDF hash table for broadcast hash join") + // Window operations val WINDOW: NvtxId = NvtxId("window", NvtxColor.CYAN, "Computing window function results") @@ -780,6 +783,7 @@ object NvtxRegistry { register(EXISTENCE_JOIN_SCATTER_MAP) register(EXISTENCE_JOIN_BATCH) register(BUILD_JOIN_TABLE) + register(BROADCAST_HASH_TABLE_BUILD) register(WINDOW) register(RUNNING_WINDOW) register(DOUBLE_BATCHED_WINDOW_PRE) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala index 5cfc9cb09bf..04874e34464 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -782,6 +782,17 @@ val GPU_COREDUMP_PIPE_PATTERN = conf("spark.rapids.gpu.coreDump.pipePattern") .checkValues(JoinBuildSideSelection.values.map(_.toString)) .createWithDefault(JoinBuildSideSelection.AUTO.toString) + val BROADCAST_HASH_TABLE_REUSE = + conf("spark.rapids.sql.join.broadcastHashTable.reuse") + .doc("Enable reuse of the broadcast-side hash table for broadcast hash joins. " + + "When enabled, the hash table is built once per broadcast and shared across all " + + "stream batches within a task and across all tasks that consume the same broadcast " + + "on an executor. 
Reuse pins the physical build side to the broadcast side for the " + + "lifetime of each cached join, overriding the dynamic build-side selection " + + s"heuristic configured by ${JOIN_BUILD_SIDE.key}.") + .booleanConf + .createWithDefault(false) + val LOG_JOIN_CARDINALITY = conf("spark.rapids.sql.join.logCardinality") .doc("Enable logging of join cardinality statistics to help diagnose performance issues. " + "When enabled, logs task context, key data types, join condition, row counts, and " + @@ -3280,6 +3291,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging { lazy val joinGathererSizeEstimateThreshold: Double = get(JOIN_GATHERER_SIZE_ESTIMATE_THRESHOLD) + lazy val broadcastHashTableReuse: Boolean = get(BROADCAST_HASH_TABLE_REUSE) + /** * Get join options based on the current configuration. * @param targetSize the target batch size in bytes to use for the join diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala index 2a05e486e9a..3dd0122edce 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala @@ -278,11 +278,18 @@ trait SpillableHandle extends StoreHandle with Logging { } /** - * Spillable handles that can be materialized on the device. + * Contract for handles tracked by the device spill store (see [[SpillableHandle]]). + */ +trait DeviceStoreHandle extends SpillableHandle { + def releaseSpilled(): Unit +} + +/** + * Spillable handles that can be materialized on the device and spilled to host. * @tparam T an auto closeable subclass. `dev` tracks an instance of this object, * on the device. 
*/ -trait DeviceSpillableHandle[T <: AutoCloseable] extends SpillableHandle { +trait DeviceSpillableHandle[T <: AutoCloseable] extends DeviceStoreHandle { private[spill] var dev: Option[T] private[spill] override def spillable: Boolean = synchronized { @@ -305,7 +312,7 @@ trait DeviceSpillableHandle[T <: AutoCloseable] extends SpillableHandle { * free a device buffer that the worker thread isn't done with). * See https://github.com/NVIDIA/spark-rapids/issues/8610 for more info. */ - def releaseSpilled(): Unit = { + override def releaseSpilled(): Unit = { releaseDeviceResource() } @@ -319,6 +326,173 @@ trait DeviceSpillableHandle[T <: AutoCloseable] extends SpillableHandle { } } +object SharedRecomputableDeviceHandle { + final class Lease[T <: AutoCloseable] private[spill] ( + handle: SharedRecomputableDeviceHandle[T], + val resource: T) extends AutoCloseable { + private[this] var closed = false + + override def close(): Unit = synchronized { + if (closed) { + throw new IllegalStateException("Close called too many times on recomputable handle lease") + } + closed = true + handle.releasePin() + } + } + + def apply[T <: AutoCloseable]( + approxSizeInBytes: Long, + initialValue: T)( + rebuild: => T): SharedRecomputableDeviceHandle[T] = { + val handle = new SharedRecomputableDeviceHandle(approxSizeInBytes, initialValue, () => rebuild) + SpillFramework.stores.deviceStore.track(handle) + handle + } +} + +/** + * Handle for device-only state that is cheaper to recompute than to spill. + * + * When this handle is selected for spilling, it does not copy anything to host or disk. Instead + * it marks the current device state as evicted and returns `approxSizeInBytes` so the spill + * framework accounts for the freed device memory. The actual close of the evicted state is + * deferred to `releaseSpilled`, after device synchronization has completed. 
+ * + * The protected device state does not expose reference counts, so spillability cannot follow the + * usual `getRefCount == 1` pattern used by cuDF tables and buffers. Instead, callers must pin the + * state through `acquire`, and spillability is derived from an application-level pin count. + */ +class SharedRecomputableDeviceHandle[T <: AutoCloseable] private[spill] ( + override val approxSizeInBytes: Long, + initialValue: T, + rebuild: () => T) extends DeviceStoreHandle with Logging { + import SharedRecomputableDeviceHandle.Lease + + private[spill] var dev: Option[T] = Some(initialValue) + private[this] var pendingRelease: Seq[T] = Seq.empty + private[this] var pinCount: Int = 0 + private[this] var rebuilding: Boolean = false + + private[spill] override def spillable: Boolean = synchronized { + super.spillable && dev.isDefined && pinCount == 0 + } + + def acquire(): Lease[T] = { + var materialized: Option[T] = None + var shouldBuild = false + while (materialized.isEmpty) { + shouldBuild = synchronized { + if (closed) { + throw new IllegalStateException("attempting to materialize a closed handle") + } else if (dev.isDefined) { + pinCount += 1 + materialized = dev + false + } else if (rebuilding) { + wait() + false + } else { + rebuilding = true + true + } + } + + if (shouldBuild) { + var rebuilt: Option[T] = None + try { + rebuilt = Some(rebuild()) + var shouldTrack = false + synchronized { + rebuilding = false + if (closed) { + notifyAll() + throw new IllegalStateException("attempting to materialize a closed handle") + } + dev = rebuilt + pinCount += 1 + materialized = rebuilt + shouldTrack = true + notifyAll() + } + if (shouldTrack) { + SpillFramework.stores.deviceStore.track(this) + } + } catch { + case t: Throwable => + rebuilt.foreach(_.close()) + synchronized { + rebuilding = false + notifyAll() + } + throw t + } + } + } + new Lease(this, materialized.get) + } + + private[spill] def releasePin(): Unit = synchronized { + if (pinCount <= 0) { + throw new 
IllegalStateException("releasePin called without a matching acquire") + } + pinCount -= 1 + } + + override def spill(): Long = { + var evicted: Option[T] = None + val thisThreadSpills = synchronized { + if (!closed && dev.isDefined && pinCount == 0 && !spilling) { + spilling = true + evicted = dev + dev = None + true + } else { + false + } + } + if (thisThreadSpills) { + SpillFramework.removeFromDeviceStore(this) + var shouldClose = false + executeSpill { + synchronized { + pendingRelease = pendingRelease ++ evicted.toSeq + spilling = false + shouldClose = closed + } + 0L + } + if (shouldClose) { + doClose() + } + approxSizeInBytes + } else { + 0L + } + } + + override def releaseSpilled(): Unit = { + val toClose = synchronized { + val release = pendingRelease + pendingRelease = Seq.empty + release + } + toClose.safeClose() + } + + override def doClose(): Unit = { + SpillFramework.removeFromDeviceStore(this) + val toClose = synchronized { + val current = dev + val release = pendingRelease + dev = None + pendingRelease = Seq.empty + current.toSeq ++ release.toSeq + } + toClose.safeClose() + } +} + /** * Spillable handles that can be materialized on the host. * @tparam T an auto closeable subclass. `host` tracks an instance of this object, @@ -1739,7 +1913,7 @@ class SpillableHostStore(val maxSize: Option[Long] = None) override protected def spillNvtxRange: NvtxId = NvtxRegistry.DISK_SPILL } -class SpillableDeviceStore extends SpillableStore[DeviceSpillableHandle[_]] { +class SpillableDeviceStore extends SpillableStore[DeviceStoreHandle] { override protected def spillNvtxRange: NvtxId = NvtxRegistry.DEVICE_SPILL override def postSpill(plan: SpillPlan): Unit = { @@ -2152,7 +2326,7 @@ object SpillFramework extends Logging { // if the stores have already shut down, we don't want to create them here // so we use `storesInternal` directly in these remove functions. 
- private[spill] def removeFromDeviceStore(handle: DeviceSpillableHandle[_]): Unit = { + private[spill] def removeFromDeviceStore(handle: DeviceStoreHandle): Unit = { synchronized { Option(storesInternal).map(_.deviceStore) }.foreach(_.remove(handle)) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExecutionPlanCaptureCallback.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExecutionPlanCaptureCallback.scala index 2ca8ac0bf9e..a7f3f59d323 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExecutionPlanCaptureCallback.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExecutionPlanCaptureCallback.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2025, NVIDIA CORPORATION. + * Copyright (c) 2022-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,6 +35,8 @@ trait ExecutionPlanCaptureCallbackBase { def assertContainsAnsiCast(df: DataFrame): Unit def assertNotContain(gpuPlan: SparkPlan, className: String): Unit def assertNotContain(df: DataFrame, gpuClass: String): Unit + def sumMetric(gpuPlan: SparkPlan, className: String, metricName: String): Long + def sumMetric(df: DataFrame, className: String, metricName: String): Long def assertDidFallBack(gpuPlan: SparkPlan, fallbackCpuClass: String): Unit def assertDidFallBack(df: DataFrame, fallbackCpuClass: String): Unit def assertDidFallBack(gpuPlans: Array[SparkPlan], fallbackCpuClass: String): Unit @@ -85,6 +87,12 @@ object ExecutionPlanCaptureCallback extends ExecutionPlanCaptureCallbackBase { override def assertNotContain(df: DataFrame, gpuClass: String): Unit = impl.assertNotContain(df, gpuClass) + override def sumMetric(gpuPlan: SparkPlan, className: String, metricName: String): Long = + impl.sumMetric(gpuPlan, className, metricName) + + override def sumMetric(df: DataFrame, className: String, metricName: String): Long = + impl.sumMetric(df, 
className, metricName) + override def assertDidFallBack(gpuPlan: SparkPlan, fallbackCpuClass: String): Unit = impl.assertDidFallBack(gpuPlan, fallbackCpuClass) @@ -128,4 +136,4 @@ class ExecutionPlanCaptureCallback extends QueryExecutionListener { trait AdaptiveSparkPlanHelperShim { def collect[B](p: SparkPlan)(pf: PartialFunction[SparkPlan, B]): Seq[B] -} \ No newline at end of file +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ShimmedExecutionPlanCaptureCallbackImpl.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ShimmedExecutionPlanCaptureCallbackImpl.scala index ccc430c65b9..98840b5272c 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ShimmedExecutionPlanCaptureCallbackImpl.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ShimmedExecutionPlanCaptureCallbackImpl.scala @@ -185,6 +185,21 @@ class ShimmedExecutionPlanCaptureCallbackImpl extends ExecutionPlanCaptureCallba assertNotContain(executedPlan, gpuClass) } + override def sumMetric(gpuPlan: SparkPlan, className: String, metricName: String): Long = { + val executedPlan = extractExecutedPlan(gpuPlan) + val matchingPlans = collectPlansMatching(executedPlan, p => PlanUtils.sameClass(p, className)) + assert(matchingPlans.nonEmpty, s"Could not find $className in the Spark plan\n$executedPlan") + matchingPlans.map { plan => + plan.metrics.getOrElse(metricName, { + throw new AssertionError(s"Could not find metric $metricName on plan node\n$plan") + }).value + }.sum + } + + override def sumMetric(df: DataFrame, className: String, metricName: String): Long = { + sumMetric(df.queryExecution.executedPlan, className, metricName) + } + override def assertContainsAnsiCast(df: DataFrame): Unit = { val executedPlan = extractExecutedPlan(df.queryExecution.executedPlan) assert(containsPlanMatching(executedPlan, @@ -253,5 +268,23 @@ class ShimmedExecutionPlanCaptureCallbackImpl extends ExecutionPlanCaptureCallba case p => p.children.exists(plan => 
containsPlanMatching(plan, f)) }.nonEmpty -} + private def collectPlansMatching(plan: SparkPlan, f: SparkPlan => Boolean): Seq[SparkPlan] = { + val matching = ArrayBuffer[SparkPlan]() + + def recurse(currentPlan: SparkPlan): Unit = currentPlan match { + case p: AdaptiveSparkPlanExec => recurse(p.executedPlan) + case p: QueryStageExec => recurse(p.plan) + case p: ReusedSubqueryExec => recurse(p.child) + case p: ReusedExchangeExec => recurse(p.child) + case p => + if (f(p)) { + matching += p + } + p.children.foreach(recurse) + } + + recurse(plan) + matching.toSeq + } +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/BroadcastCachedBuildSide.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/BroadcastCachedBuildSide.scala new file mode 100644 index 00000000000..6b7a81b12c0 --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/BroadcastCachedBuildSide.scala @@ -0,0 +1,141 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.rapids.execution + +import ai.rapids.cudf.{DistinctHashJoin => CudfDistinctHashJoin, HashJoin => CudfHashJoin, Table} +import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuProjectExec, NvtxRegistry, SpillableColumnarBatch} +import com.nvidia.spark.rapids.Arm.withResource +import com.nvidia.spark.rapids.RmmRapidsRetryIterator.withRetryNoSplit +import com.nvidia.spark.rapids.spill.SharedRecomputableDeviceHandle + +import org.apache.spark.sql.vectorized.ColumnarBatch + +sealed trait CachedBuildSide extends AutoCloseable { + def buildStats: JoinBuildSideStats +} + +final class CachedHashJoin( + override val buildStats: JoinBuildSideStats, + val handle: SharedRecomputableDeviceHandle[CudfHashJoin]) extends CachedBuildSide { + override def close(): Unit = handle.close() +} + +final class CachedDistinctHashJoin( + override val buildStats: JoinBuildSideStats, + val handle: SharedRecomputableDeviceHandle[CudfDistinctHashJoin]) extends CachedBuildSide { + override def close(): Unit = handle.close() +} + +case class BroadcastCachedBuildSideKey( + projectedBuildKeys: Seq[String], + compareNullsEqual: Boolean, + filterOutNulls: Boolean) + +object BroadcastCachedBuildSide { + def key( + boundBuiltKeys: Seq[GpuExpression], + compareNullsEqual: Boolean, + filterOutNulls: Boolean): BroadcastCachedBuildSideKey = { + BroadcastCachedBuildSideKey( + boundBuiltKeys.map(_.canonicalized.sql), + compareNullsEqual, + filterOutNulls) + } + + private def withBuildKeys[T]( + broadcastBatch: SpillableColumnarBatch, + boundBuiltKeys: Seq[GpuExpression], + filterOutNulls: Boolean)(f: Table => T): T = { + def projectAndApply(cb: ColumnarBatch): T = { + withResource(GpuProjectExec.project(cb, boundBuiltKeys)) { buildKeys => + withResource(GpuColumnVector.from(buildKeys)) { buildKeysTable => + f(buildKeysTable) + } + } + } + if (filterOutNulls) { + val retainedBatch = broadcastBatch.incRefCount() + 
withResource(GpuHashJoin.filterNullsWithRetryAndClose(retainedBatch, boundBuiltKeys)) { + projectAndApply + } + } else { + val retainedBatch = broadcastBatch.incRefCount() + withRetryNoSplit(retainedBatch) { _ => + withResource(retainedBatch.getColumnarBatch()) { projectAndApply } + } + } + } + + private def newHashJoin(buildKeys: Table, compareNullsEqual: Boolean): CudfHashJoin = { + NvtxRegistry.BROADCAST_HASH_TABLE_BUILD { + new CudfHashJoin(buildKeys, compareNullsEqual) + } + } + + private def newDistinctHashJoin( + buildKeys: Table, + compareNullsEqual: Boolean): CudfDistinctHashJoin = { + NvtxRegistry.BROADCAST_HASH_TABLE_BUILD { + new CudfDistinctHashJoin(buildKeys, compareNullsEqual) + } + } + + /** + * cuDF's reusable hash join handles are safe for concurrent probes. The executor-wide cache + * therefore pins the live handle while a task is probing it and relies on + * `SharedRecomputableDeviceHandle` to track spillability with an application-level pin count. + */ + def create( + broadcastBatch: SpillableColumnarBatch, + boundBuiltKeys: Seq[GpuExpression], + compareNullsEqual: Boolean, + filterOutNulls: Boolean): CachedBuildSide = { + def buildHashJoin(): CudfHashJoin = { + withBuildKeys(broadcastBatch, boundBuiltKeys, filterOutNulls) { buildKeys => + newHashJoin(buildKeys, compareNullsEqual) + } + } + + def buildDistinctHashJoin(): CudfDistinctHashJoin = { + withBuildKeys(broadcastBatch, boundBuiltKeys, filterOutNulls) { buildKeys => + newDistinctHashJoin(buildKeys, compareNullsEqual) + } + } + + withBuildKeys(broadcastBatch, boundBuiltKeys, filterOutNulls) { buildKeys => + val stats = JoinBuildSideStats.fromTable(buildKeys) + val approxSizeInBytes = buildKeys.getDeviceMemorySize + if (stats.isDistinct) { + new CachedDistinctHashJoin( + stats, + SharedRecomputableDeviceHandle( + approxSizeInBytes, + newDistinctHashJoin(buildKeys, compareNullsEqual)) { + buildDistinctHashJoin() + }) + } else { + new CachedHashJoin( + stats, + 
SharedRecomputableDeviceHandle( + approxSizeInBytes, + newHashJoin(buildKeys, compareNullsEqual)) { + buildHashJoin() + }) + } + } + } +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala index 8420a482cdf..055120bd3ad 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala @@ -80,9 +80,18 @@ class SerializeConcatHostBuffersDeserializeBatch( // used for memoization of deserialization to GPU on Executor @transient private var batchInternal: SpillableColumnarBatch = null + @transient private var cachedBuildSideCache: + mutable.HashMap[BroadcastCachedBuildSideKey, CachedBuildSide] = null private def maybeGpuBatch: Option[SpillableColumnarBatch] = Option(batchInternal) + private def cachedBuildSides: mutable.HashMap[BroadcastCachedBuildSideKey, CachedBuildSide] = { + if (cachedBuildSideCache == null) { + cachedBuildSideCache = mutable.HashMap.empty + } + cachedBuildSideCache + } + def batch: SpillableColumnarBatch = this.synchronized { maybeGpuBatch.getOrElse { NvtxRegistry.BROADCAST_MANIFEST_BATCH { @@ -148,6 +157,31 @@ class SerializeConcatHostBuffersDeserializeBatch( } } + def getCachedBuildSide( + boundBuiltKeys: Seq[GpuExpression], + compareNullsEqual: Boolean, + filterOutNulls: Boolean, + cacheBuilds: GpuMetric = NoopMetric, + cacheHits: GpuMetric = NoopMetric): CachedBuildSide = this.synchronized { + val cacheKey = BroadcastCachedBuildSide.key( + boundBuiltKeys, + compareNullsEqual, + filterOutNulls) + cachedBuildSides.get(cacheKey).map { cached => + cacheHits += 1 + cached + }.getOrElse { + cacheBuilds += 1 + val cached = BroadcastCachedBuildSide.create( + batch, + boundBuiltKeys, + compareNullsEqual, + filterOutNulls) + cachedBuildSides.put(cacheKey, cached) 
+ cached + } + } + private def writeObject(out: ObjectOutputStream): Unit = { doWriteObject(out) } @@ -246,8 +280,10 @@ class SerializeConcatHostBuffersDeserializeBatch( */ def closeInternal(): Unit = this.synchronized { Seq(data, batchInternal).safeClose() + Option(cachedBuildSideCache).foreach(cache => cache.values.toSeq.safeClose()) data = null batchInternal = null + cachedBuildSideCache = null } @scala.annotation.nowarn("msg=method finalize in class Object is deprecated") diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExecBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExecBase.scala index 9fad3be3153..48eb0041a6e 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExecBase.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExecBase.scala @@ -119,7 +119,9 @@ abstract class GpuBroadcastHashJoinExecBase( override lazy val additionalMetrics: Map[String, GpuMetric] = Map( OP_TIME_LEGACY -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_OP_TIME_LEGACY), STREAM_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_STREAM_TIME), - JOIN_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_JOIN_TIME)) + JOIN_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_JOIN_TIME), + BUILD_SIDE_CACHE_BUILDS -> createMetric(DEBUG_LEVEL, DESCRIPTION_BUILD_SIDE_CACHE_BUILDS), + BUILD_SIDE_CACHE_HITS -> createMetric(DEBUG_LEVEL, DESCRIPTION_BUILD_SIDE_CACHE_HITS)) override def requiredChildDistribution: Seq[Distribution] = { val mode = HashedRelationBroadcastMode(buildKeys) @@ -152,9 +154,12 @@ abstract class GpuBroadcastHashJoinExecBase( val opTime = gpuLongMetric(OP_TIME_LEGACY) val streamTime = gpuLongMetric(STREAM_TIME) val joinTime = gpuLongMetric(JOIN_TIME) + val buildSideCacheBuilds = gpuLongMetric(BUILD_SIDE_CACHE_BUILDS) + val buildSideCacheHits = 
gpuLongMetric(BUILD_SIDE_CACHE_HITS) val targetSize = RapidsConf.GPU_BATCH_SIZE_BYTES.get(conf) val joinOptions = RapidsConf.getJoinOptions(conf, targetSize) + val enableBuildSideReuse = RapidsConf.BROADCAST_HASH_TABLE_REUSE.get(conf) val broadcastRelation = broadcastExchange.executeColumnarBroadcast[Any]() @@ -167,6 +172,10 @@ abstract class GpuBroadcastHashJoinExecBase( broadcastRelation, buildSchema, new CollectTimeIterator(NvtxRegistry.BROADCAST_JOIN_STREAM, it, streamTime)) + val broadcastBatch = broadcastRelation.value match { + case batch: SerializeConcatHostBuffersDeserializeBatch => Some(batch) + case _ => None + } if (localIsNullAwareAntiJoin) { // This is to support the null-aware anti join for the LeftAnti join with // BuildRight. See the config "spark.sql.optimizeNullAwareAntiJoin". @@ -187,12 +196,18 @@ abstract class GpuBroadcastHashJoinExecBase( boundStreamKeys) } doJoin(builtBatch, nullFilteredStreamIter, joinOptions, numOutputRows, - numOutputBatches, opTime, joinTime) + numOutputBatches, opTime, joinTime, enableBuildSideReuse, + broadcastBatch = broadcastBatch, + buildSideCacheBuilds = buildSideCacheBuilds, + buildSideCacheHits = buildSideCacheHits) } } else { // builtBatch will be closed in doJoin doJoin(builtBatch, streamIter, joinOptions, numOutputRows, numOutputBatches, opTime, - joinTime) + joinTime, enableBuildSideReuse, + broadcastBatch = broadcastBatch, + buildSideCacheBuilds = buildSideCacheBuilds, + buildSideCacheHits = buildSideCacheHits) } } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala index 2294d56b266..5a23f97b0b1 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala @@ -15,7 +15,7 @@ */ package org.apache.spark.sql.rapids.execution -import ai.rapids.cudf.{ColumnView, 
DType, GatherMap, NullEquality, OutOfBoundsPolicy, Scalar, Table} +import ai.rapids.cudf.{ColumnView, DistinctHashJoin => CudfDistinctHashJoin, DType, GatherMap, HashJoin => CudfHashJoin, NullEquality, OutOfBoundsPolicy, Scalar, Table} import ai.rapids.cudf.ast.CompiledExpression import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} @@ -299,6 +299,82 @@ object JoinImpl { GatherMapsResult.makeFromLeft(leftRet) } + def innerHashJoinBuildLeft( + rightKeys: Table, + leftHashJoin: CudfHashJoin, + outputRowCount: Option[Long] = None): GatherMapsResult = { + val arrayRet = outputRowCount.map(rightKeys.innerJoinGatherMaps(leftHashJoin, _)) + .getOrElse(rightKeys.innerJoinGatherMaps(leftHashJoin)) + GatherMapsResult(arrayRet(1), arrayRet(0)) + } + + def innerHashJoinBuildRight( + leftKeys: Table, + rightHashJoin: CudfHashJoin, + outputRowCount: Option[Long] = None): GatherMapsResult = { + val arrayRet = outputRowCount.map(leftKeys.innerJoinGatherMaps(rightHashJoin, _)) + .getOrElse(leftKeys.innerJoinGatherMaps(rightHashJoin)) + GatherMapsResult(arrayRet(0), arrayRet(1)) + } + + def leftOuterHashJoinBuildRight( + leftKeys: Table, + rightHashJoin: CudfHashJoin, + outputRowCount: Option[Long] = None): GatherMapsResult = { + val arrayRet = outputRowCount.map(leftKeys.leftJoinGatherMaps(rightHashJoin, _)) + .getOrElse(leftKeys.leftJoinGatherMaps(rightHashJoin)) + GatherMapsResult(arrayRet(0), arrayRet(1)) + } + + def rightOuterHashJoinBuildLeft( + rightKeys: Table, + leftHashJoin: CudfHashJoin, + outputRowCount: Option[Long] = None): GatherMapsResult = { + val arrayRet = outputRowCount.map(rightKeys.leftJoinGatherMaps(leftHashJoin, _)) + .getOrElse(rightKeys.leftJoinGatherMaps(leftHashJoin)) + GatherMapsResult(arrayRet(1), arrayRet(0)) + } + + def innerDistinctHashJoinBuildLeft( + rightKeys: Table, + leftHashJoin: CudfDistinctHashJoin): GatherMapsResult = { + val arrayRet = rightKeys.innerJoinGatherMaps(leftHashJoin) + 
GatherMapsResult(arrayRet(1), arrayRet(0)) + } + + def innerDistinctHashJoinBuildRight( + leftKeys: Table, + rightHashJoin: CudfDistinctHashJoin): GatherMapsResult = { + val arrayRet = leftKeys.innerJoinGatherMaps(rightHashJoin) + GatherMapsResult(arrayRet(0), arrayRet(1)) + } + + def leftOuterDistinctHashJoinBuildRight( + leftKeys: Table, + rightHashJoin: CudfDistinctHashJoin): GatherMapsResult = { + val rightRet = leftKeys.leftDistinctJoinGatherMap(rightHashJoin) + GatherMapsResult.makeFromRight(rightRet) + } + + def rightOuterDistinctHashJoinBuildLeft( + rightKeys: Table, + leftHashJoin: CudfDistinctHashJoin): GatherMapsResult = { + val leftRet = rightKeys.leftDistinctJoinGatherMap(leftHashJoin) + GatherMapsResult.makeFromLeft(leftRet) + } + + def innerHashJoinBuildLeftRowCount(rightKeys: Table, leftHashJoin: CudfHashJoin): Long = + rightKeys.innerJoinRowCount(leftHashJoin) + + def innerHashJoinBuildRightRowCount(leftKeys: Table, rightHashJoin: CudfHashJoin): Long = + leftKeys.innerJoinRowCount(rightHashJoin) + + def leftOuterHashJoinBuildRightRowCount(leftKeys: Table, rightHashJoin: CudfHashJoin): Long = + leftKeys.leftJoinRowCount(rightHashJoin) + + def rightOuterHashJoinBuildLeftRowCount(rightKeys: Table, leftHashJoin: CudfHashJoin): Long = + rightKeys.leftJoinRowCount(leftHashJoin) + /** * Do an inner hash join with the build table as the left table. 
* @param leftKeys the left equality join keys @@ -1012,6 +1088,13 @@ case class JoinCardinalityStats( case class JoinBuildSideStats(streamMagnificationFactor: Double, isDistinct: Boolean) object JoinBuildSideStats { + def fromTable(buildKeys: Table): JoinBuildSideStats = { + val builtCount = buildKeys.distinctCount(NullEquality.EQUAL) + val isDistinct = builtCount == buildKeys.getRowCount + val magnificationFactor = buildKeys.getRowCount.toDouble / builtCount + JoinBuildSideStats(magnificationFactor, isDistinct) + } + def fromBatch(batch: ColumnarBatch, boundBuildKeys: Seq[GpuExpression]): JoinBuildSideStats = { // This is okay because the build keys must be deterministic @@ -1020,10 +1103,7 @@ object JoinBuildSideStats { // will be for each input row on the stream side. This does not take into account // the join type, data skew or even if the keys actually match. withResource(GpuColumnVector.from(buildKeys)) { keysTable => - val builtCount = keysTable.distinctCount(NullEquality.EQUAL) - val isDistinct = builtCount == buildKeys.numRows() - val magnificationFactor = buildKeys.numRows().toDouble / builtCount - JoinBuildSideStats(magnificationFactor, isDistinct) + fromTable(keysTable) } } } @@ -1040,12 +1120,15 @@ abstract class BaseHashJoinIterator( built: LazySpillableColumnarBatch, boundBuiltKeys: Seq[GpuExpression], buildStatsOpt: Option[JoinBuildSideStats], + cachedBuildSide: Option[CachedBuildSide], stream: Iterator[LazySpillableColumnarBatch], boundStreamKeys: Seq[GpuExpression], streamAttributes: Seq[Attribute], joinOptions: JoinOptions, joinType: JoinType, buildSide: GpuBuildSide, + enableBuildSideReuse: Boolean, + compareNullsEqual: Boolean, conditionForLogging: Option[Expression], opTime: GpuMetric, joinTime: GpuMetric) @@ -1059,9 +1142,10 @@ abstract class BaseHashJoinIterator( opTime = opTime, joinTime = joinTime) { // We can cache this because the build side is not changing - protected lazy val buildStats: JoinBuildSideStats = buildStatsOpt.getOrElse { 
+ protected lazy val buildStats: JoinBuildSideStats = buildStatsOpt + .orElse(cachedBuildSide.map(_.buildStats)).getOrElse { joinType match { - case _: InnerLike | LeftOuter | RightOuter | FullOuter => + case _: InnerLike | LeftOuter | RightOuter | FullOuter | LeftSemi | LeftAnti => built.checkpoint() withRetryNoSplit { withRestoreOnRetry(built) { @@ -1074,6 +1158,60 @@ abstract class BaseHashJoinIterator( } } + private[this] var cachedBuildSideDisabled = !enableBuildSideReuse || cachedBuildSide.isEmpty + + private def canUseCachedBuildSide(expectedBuildSide: GpuBuildSide): Boolean = { + !cachedBuildSideDisabled && + expectedBuildSide == buildSide && + cachedBuildSide.isDefined + } + + protected def canUseCachedHashJoin(expectedBuildSide: GpuBuildSide): Boolean = { + canUseCachedBuildSide(expectedBuildSide) && cachedBuildSide.exists(_.isInstanceOf[CachedHashJoin]) + } + + protected def canUseCachedDistinctHashJoin(expectedBuildSide: GpuBuildSide): Boolean = { + canUseCachedBuildSide(expectedBuildSide) && + cachedBuildSide.exists(_.isInstanceOf[CachedDistinctHashJoin]) + } + + protected def withCachedHashJoin[T](expectedBuildSide: GpuBuildSide)(f: CudfHashJoin => T): Option[T] = { + if (!canUseCachedHashJoin(expectedBuildSide)) { + None + } else { + cachedBuildSide.flatMap { cached => + cached match { + case hashJoin: CachedHashJoin => + withResource(hashJoin.handle.acquire()) { lease => + Some(f(lease.resource)) + } + case _ => None + } + } + } + } + + protected def withCachedDistinctHashJoin[T]( + expectedBuildSide: GpuBuildSide)(f: CudfDistinctHashJoin => T): Option[T] = { + if (!canUseCachedDistinctHashJoin(expectedBuildSide)) { + None + } else { + cachedBuildSide.flatMap { cached => + cached match { + case distinctHashJoin: CachedDistinctHashJoin => + withResource(distinctHashJoin.handle.acquire()) { lease => + Some(f(lease.resource)) + } + case _ => None + } + } + } + } + + protected def disableCachedBuildSide(): Unit = { + cachedBuildSideDisabled = true + } + 
/** * Check if sort join is supported for the given key expressions. * Sort join does not support ARRAY or STRUCT types in join keys. @@ -1281,7 +1419,7 @@ abstract class BaseHashJoinIterator( withResource(GpuProjectExec.project(built.getBatch, boundBuiltKeys)) { builtKeys => // ensure that the build data can be spilled built.allowSpilling() - joinGatherer(builtKeys, built, streamBatch) + joinGatherer(builtKeys, built, streamBatch, numJoinRows) } } } @@ -1294,6 +1432,7 @@ abstract class BaseHashJoinIterator( || joinType == LeftOuter || joinType == RightOuter || joinType == FullOuter => + disableCachedBuildSide() // Because this is just an estimate, it is possible for us to get this wrong, so // make sure we at least split the batch in half. val numBatches = Math.max(2, estimatedNumBatches(spillOnlyCb)) @@ -1317,16 +1456,18 @@ abstract class BaseHashJoinIterator( leftKeys: Table, leftData: LazySpillableColumnarBatch, rightKeys: Table, - rightData: LazySpillableColumnarBatch): Option[JoinGatherer] + rightData: LazySpillableColumnarBatch, + numJoinRows: Option[Long]): Option[JoinGatherer] private def joinGathererLeftRight( leftKeys: ColumnarBatch, leftData: LazySpillableColumnarBatch, rightKeys: ColumnarBatch, - rightData: LazySpillableColumnarBatch): Option[JoinGatherer] = { + rightData: LazySpillableColumnarBatch, + numJoinRows: Option[Long]): Option[JoinGatherer] = { withResource(GpuColumnVector.from(leftKeys)) { leftKeysTab => withResource(GpuColumnVector.from(rightKeys)) { rightKeysTab => - joinGathererLeftRight(leftKeysTab, leftData, rightKeysTab, rightData) + joinGathererLeftRight(leftKeysTab, leftData, rightKeysTab, rightData, numJoinRows) } } } @@ -1335,23 +1476,33 @@ abstract class BaseHashJoinIterator( buildKeys: ColumnarBatch, buildData: LazySpillableColumnarBatch, streamKeys: ColumnarBatch, - streamData: LazySpillableColumnarBatch): Option[JoinGatherer] = { + streamData: LazySpillableColumnarBatch, + numJoinRows: Option[Long]): Option[JoinGatherer] = { 
buildSide match { case GpuBuildLeft => - joinGathererLeftRight(buildKeys, buildData, streamKeys, streamData) + joinGathererLeftRight(buildKeys, buildData, streamKeys, streamData, numJoinRows) case GpuBuildRight => - joinGathererLeftRight(streamKeys, streamData, buildKeys, buildData) + joinGathererLeftRight(streamKeys, streamData, buildKeys, buildData, numJoinRows) } } private def joinGatherer( buildKeys: ColumnarBatch, buildData: LazySpillableColumnarBatch, - streamCb: LazySpillableColumnarBatch): Option[JoinGatherer] = { + streamCb: LazySpillableColumnarBatch, + numJoinRows: Option[Long]): Option[JoinGatherer] = { withResource(GpuProjectExec.project(streamCb.getBatch, boundStreamKeys)) { streamKeys => // ensure we make the stream side spillable again streamCb.allowSpilling() - joinGatherer(buildKeys, LazySpillableColumnarBatch.spillOnly(buildData), streamKeys, streamCb) + joinGatherer(buildKeys, LazySpillableColumnarBatch.spillOnly(buildData), streamKeys, streamCb, + numJoinRows) + } + } + + override def close(): Unit = { + if (!closed) { + disableCachedBuildSide() + super.close() } } @@ -1384,25 +1535,58 @@ class HashJoinIterator( val compareNullsEqual: Boolean, // This is a workaround to how cudf support joins for structs conditionForLogging: Option[Expression], opTime: GpuMetric, - private val joinTime: GpuMetric) + private val joinTime: GpuMetric, + enableBuildSideReuse: Boolean = false, + cachedBuildSide: Option[CachedBuildSide] = None) extends BaseHashJoinIterator( built, boundBuiltKeys, buildStatsOpt, + cachedBuildSide, stream, boundStreamKeys, streamAttributes, joinOptions, joinType, buildSide, + enableBuildSideReuse, + compareNullsEqual, conditionForLogging, opTime = opTime, joinTime = joinTime) { + + override def computeNumJoinRows(cb: LazySpillableColumnarBatch): Long = { + lazy val fallback = super.computeNumJoinRows(cb) + if (buildStats.isDistinct) { + fallback + } else { + cb.checkpoint() + try { + withRetryNoSplit { + withRestoreOnRetry(cb) { + 
withResource(GpuProjectExec.project(cb.getBatch, boundStreamKeys)) { streamKeys => + withResource(GpuColumnVector.from(streamKeys)) { streamKeysTable => + exactNumJoinRows(streamKeysTable).getOrElse(fallback) + } + } + } + } + } catch { + case _: OutOfMemoryError | _: GpuOOM => + disableCachedBuildSide() + fallback + } finally { + cb.allowSpilling() + } + } + } + override protected def joinGathererLeftRight( leftKeys: Table, leftData: LazySpillableColumnarBatch, rightKeys: Table, - rightData: LazySpillableColumnarBatch): Option[JoinGatherer] = { + rightData: LazySpillableColumnarBatch, + numJoinRows: Option[Long]): Option[JoinGatherer] = { NvtxIdWithMetrics(NvtxRegistry.HASH_JOIN_GATHER_MAP, joinTime) { // hack to work around unique_join not handling empty tables if (joinType.isInstanceOf[InnerLike] && @@ -1415,30 +1599,12 @@ class HashJoinIterator( val maps = if (buildStats.isDistinct) { // Distinct join optimizations (highest priority, overrides strategy) - logJoinCardinality(leftKeys, rightKeys, "distinct") - val result = joinType match { - case LeftOuter => - val rightRet = leftKeys.leftDistinctJoinGatherMap(rightKeys, compareNullsEqual) - GatherMapsResult.makeFromRight(rightRet) - case RightOuter => - val leftRet = rightKeys.leftDistinctJoinGatherMap(leftKeys, compareNullsEqual) - GatherMapsResult.makeFromLeft(leftRet) - case _: InnerLike => - val arrayRet = if (buildSide == GpuBuildRight) { - leftKeys.innerDistinctJoinGatherMaps(rightKeys, compareNullsEqual) - } else { - rightKeys.innerDistinctJoinGatherMaps(leftKeys, compareNullsEqual).reverse - } - GatherMapsResult(arrayRet(0), arrayRet(1)) - case _ => - // Fall through to strategy-based dispatching for non-outer joins - computeNonDistinctJoin(leftKeys, rightKeys, leftData, rightData) - } + val result = computeDistinctJoin(leftKeys, rightKeys) logJoinCompletion() result } else { // Non-distinct joins: use strategy-based dispatching - computeNonDistinctJoin(leftKeys, rightKeys, leftData, rightData) + 
computeNonDistinctJoin(leftKeys, rightKeys, numJoinRows) } makeGatherer(maps, leftData, rightData, joinType) @@ -1446,11 +1612,124 @@ class HashJoinIterator( } } + private def exactNumJoinRows(streamKeys: Table): Option[Long] = { + joinType match { + case _: InnerLike => + buildSide match { + case GpuBuildRight => + withCachedHashJoin(GpuBuildRight) { hashJoin => + JoinImpl.innerHashJoinBuildRightRowCount(streamKeys, hashJoin) + } + case GpuBuildLeft => + withCachedHashJoin(GpuBuildLeft) { hashJoin => + JoinImpl.innerHashJoinBuildLeftRowCount(streamKeys, hashJoin) + } + } + case LeftOuter => + withCachedHashJoin(GpuBuildRight) { hashJoin => + JoinImpl.leftOuterHashJoinBuildRightRowCount(streamKeys, hashJoin) + } + case RightOuter => + withCachedHashJoin(GpuBuildLeft) { hashJoin => + JoinImpl.rightOuterHashJoinBuildLeftRowCount(streamKeys, hashJoin) + } + case _ => + None + } + } + + private def cachedGenericInnerJoin( + leftKeys: Table, + rightKeys: Table, + outputRowCount: Option[Long]): Option[GatherMapsResult] = { + buildSide match { + case GpuBuildRight => + withCachedHashJoin(GpuBuildRight) { hashJoin => + JoinImpl.innerHashJoinBuildRight(leftKeys, hashJoin, outputRowCount) + } + case GpuBuildLeft => + withCachedHashJoin(GpuBuildLeft) { hashJoin => + JoinImpl.innerHashJoinBuildLeft(rightKeys, hashJoin, outputRowCount) + } + } + } + + private def reusedGenericLeftSemi( + leftKeys: Table): Option[GatherMapsResult] = { + withCachedHashJoin(GpuBuildRight) { hashJoin => + withResource(JoinImpl.innerHashJoinBuildRight(leftKeys, hashJoin)) { innerMaps => + JoinImpl.makeLeftSemi(innerMaps, leftKeys.getRowCount.toInt) + } + } + } + + private def reusedGenericLeftAnti( + leftKeys: Table): Option[GatherMapsResult] = { + withCachedHashJoin(GpuBuildRight) { hashJoin => + withResource(JoinImpl.innerHashJoinBuildRight(leftKeys, hashJoin)) { innerMaps => + JoinImpl.makeLeftAnti(innerMaps, leftKeys.getRowCount.toInt) + } + } + } + + private def computeDistinctJoin( + leftKeys: 
Table, + rightKeys: Table): GatherMapsResult = { + val reused = withCachedDistinctHashJoin(buildSide) { distinctHashJoin => + logJoinCardinality(leftKeys, rightKeys, "distinct (reused)") + joinType match { + case LeftOuter => + JoinImpl.leftOuterDistinctHashJoinBuildRight(leftKeys, distinctHashJoin) + case RightOuter => + JoinImpl.rightOuterDistinctHashJoinBuildLeft(rightKeys, distinctHashJoin) + case LeftSemi => + withResource(JoinImpl.innerDistinctHashJoinBuildRight(leftKeys, distinctHashJoin)) { innerMaps => + JoinImpl.makeLeftSemi(innerMaps, leftKeys.getRowCount.toInt) + } + case LeftAnti => + withResource(JoinImpl.innerDistinctHashJoinBuildRight(leftKeys, distinctHashJoin)) { innerMaps => + JoinImpl.makeLeftAnti(innerMaps, leftKeys.getRowCount.toInt) + } + case _: InnerLike => + if (buildSide == GpuBuildRight) { + JoinImpl.innerDistinctHashJoinBuildRight(leftKeys, distinctHashJoin) + } else { + JoinImpl.innerDistinctHashJoinBuildLeft(rightKeys, distinctHashJoin) + } + case _ => + computeDistinctJoinWithoutReuse(leftKeys, rightKeys) + } + } + reused.getOrElse(computeDistinctJoinWithoutReuse(leftKeys, rightKeys)) + } + + private def computeDistinctJoinWithoutReuse( + leftKeys: Table, + rightKeys: Table): GatherMapsResult = { + logJoinCardinality(leftKeys, rightKeys, "distinct") + joinType match { + case LeftOuter => + val rightRet = leftKeys.leftDistinctJoinGatherMap(rightKeys, compareNullsEqual) + GatherMapsResult.makeFromRight(rightRet) + case RightOuter => + val leftRet = rightKeys.leftDistinctJoinGatherMap(leftKeys, compareNullsEqual) + GatherMapsResult.makeFromLeft(leftRet) + case _: InnerLike => + val arrayRet = if (buildSide == GpuBuildRight) { + leftKeys.innerDistinctJoinGatherMaps(rightKeys, compareNullsEqual) + } else { + rightKeys.innerDistinctJoinGatherMaps(leftKeys, compareNullsEqual).reverse + } + GatherMapsResult(arrayRet(0), arrayRet(1)) + case _ => + computeNonDistinctJoin(leftKeys, rightKeys, None) + } + } + private def computeNonDistinctJoin( 
leftKeys: Table, rightKeys: Table, - leftData: LazySpillableColumnarBatch, - rightData: LazySpillableColumnarBatch): GatherMapsResult = { + numJoinRows: Option[Long]): GatherMapsResult = { // Apply heuristics to select the effective strategy val effectiveStrategy = JoinStrategy.selectStrategy( joinOptions.strategy, @@ -1463,7 +1742,7 @@ class HashJoinIterator( effectiveStrategy match { case JoinStrategy.INNER_HASH_WITH_POST => // Use composable JNI APIs: inner join -> convert to target join type - computeNonCondInnerHashWithPost(leftKeys, rightKeys) + computeNonCondInnerHashWithPost(leftKeys, rightKeys, numJoinRows) case JoinStrategy.INNER_SORT_WITH_POST => // Check if sort join is supported (no ARRAY/STRUCT types) val leftKeysSupported = isSortJoinSupported(boundBuiltKeys) @@ -1477,17 +1756,18 @@ class HashJoinIterator( s"ARRAY or STRUCT types which are not supported for sort joins. " + s"Falling back to INNER_HASH_WITH_POST strategy.") } - computeNonCondInnerHashWithPost(leftKeys, rightKeys, isFallback = true) + computeNonCondInnerHashWithPost(leftKeys, rightKeys, numJoinRows, isFallback = true) } case _ => // Use existing hash join methods (for HASH_ONLY and when AUTO doesn't trigger heuristics) - computeWithHashJoin(leftKeys, rightKeys) + computeWithHashJoin(leftKeys, rightKeys, numJoinRows) } } private def computeNonCondInnerHashWithPost( leftKeys: Table, rightKeys: Table, + numJoinRows: Option[Long], isFallback: Boolean = false): GatherMapsResult = { val implName = if (isFallback) { "INNER_HASH_WITH_POST (fallback from INNER_SORT_WITH_POST)" @@ -1496,8 +1776,10 @@ class HashJoinIterator( } logJoinCardinality(leftKeys, rightKeys, implName) - val innerMaps = JoinImpl.innerHashJoin(leftKeys, rightKeys, compareNullsEqual, - joinOptions.buildSideSelection, buildSide) + val innerMaps = cachedGenericInnerJoin(leftKeys, rightKeys, numJoinRows).getOrElse { + JoinImpl.innerHashJoin(leftKeys, rightKeys, compareNullsEqual, + joinOptions.buildSideSelection, buildSide) + } 
val leftRowCount = leftKeys.getRowCount val rightRowCount = rightKeys.getRowCount @@ -1525,21 +1807,32 @@ class HashJoinIterator( private def computeWithHashJoin( leftKeys: Table, - rightKeys: Table): GatherMapsResult = { + rightKeys: Table, + numJoinRows: Option[Long]): GatherMapsResult = { logJoinCardinality(leftKeys, rightKeys, "hash join") val result = joinType match { case LeftOuter => - JoinImpl.leftOuterHashJoinBuildRight(leftKeys, rightKeys, compareNullsEqual) + withCachedHashJoin(GpuBuildRight) { hashJoin => + JoinImpl.leftOuterHashJoinBuildRight(leftKeys, hashJoin, numJoinRows) + } + .getOrElse(JoinImpl.leftOuterHashJoinBuildRight(leftKeys, rightKeys, compareNullsEqual)) case RightOuter => - JoinImpl.rightOuterHashJoinBuildLeft(leftKeys, rightKeys, compareNullsEqual) + withCachedHashJoin(GpuBuildLeft) { hashJoin => + JoinImpl.rightOuterHashJoinBuildLeft(rightKeys, hashJoin, numJoinRows) + } + .getOrElse(JoinImpl.rightOuterHashJoinBuildLeft(leftKeys, rightKeys, compareNullsEqual)) case _: InnerLike => - JoinImpl.innerHashJoin(leftKeys, rightKeys, compareNullsEqual, - joinOptions.buildSideSelection, buildSide) + cachedGenericInnerJoin(leftKeys, rightKeys, numJoinRows).getOrElse { + JoinImpl.innerHashJoin(leftKeys, rightKeys, compareNullsEqual, + joinOptions.buildSideSelection, buildSide) + } case LeftSemi => - JoinImpl.leftSemiHashJoinBuildRight(leftKeys, rightKeys, compareNullsEqual) + reusedGenericLeftSemi(leftKeys) + .getOrElse(JoinImpl.leftSemiHashJoinBuildRight(leftKeys, rightKeys, compareNullsEqual)) case LeftAnti => - JoinImpl.leftAntiHashJoinBuildRight(leftKeys, rightKeys, compareNullsEqual) + reusedGenericLeftAnti(leftKeys) + .getOrElse(JoinImpl.leftAntiHashJoinBuildRight(leftKeys, rightKeys, compareNullsEqual)) case _ => throw new NotImplementedError(s"Join Type ${joinType.getClass} is not currently" + s" supported") @@ -1567,17 +1860,22 @@ class ConditionalHashJoinIterator( compareNullsEqual: Boolean, // This is a workaround to how cudf support 
joins for structs conditionForLogging: Option[Expression], opTime: GpuMetric, - joinTime: GpuMetric) + joinTime: GpuMetric, + enableBuildSideReuse: Boolean = false, + cachedBuildSide: Option[CachedBuildSide] = None) extends BaseHashJoinIterator( built, boundBuiltKeys, buildStatsOpt, + cachedBuildSide, stream, boundStreamKeys, streamAttributes, joinOptions, joinType, buildSide, + enableBuildSideReuse, + compareNullsEqual, conditionForLogging, opTime = opTime, joinTime = joinTime) { @@ -1592,7 +1890,8 @@ class ConditionalHashJoinIterator( leftKeys: Table, leftData: LazySpillableColumnarBatch, rightKeys: Table, - rightData: LazySpillableColumnarBatch): Option[JoinGatherer] = { + rightData: LazySpillableColumnarBatch, + numJoinRows: Option[Long]): Option[JoinGatherer] = { val nullEquality = if (compareNullsEqual) NullEquality.EQUAL else NullEquality.UNEQUAL NvtxIdWithMetrics(NvtxRegistry.HASH_JOIN_GATHER_MAP, joinTime) { withResource(GpuColumnVector.from(leftData.getBatch)) { leftTable => @@ -1783,17 +2082,22 @@ class HashJoinStreamSideIterator( compareNullsEqual: Boolean, // This is a workaround to how cudf support joins for structs conditionForLogging: Option[Expression], opTime: GpuMetric, - joinTime: GpuMetric) + joinTime: GpuMetric, + enableBuildSideReuse: Boolean = false, + cachedBuildSide: Option[CachedBuildSide] = None) extends BaseHashJoinIterator( built, boundBuiltKeys, buildStatsOpt, + cachedBuildSide, stream, boundStreamKeys, streamAttributes, joinOptions, joinType, buildSide, + enableBuildSideReuse, + compareNullsEqual, conditionForLogging, opTime = opTime, joinTime = joinTime) { @@ -1828,11 +2132,81 @@ class HashJoinStreamSideIterator( private[this] var builtSideTracker: Option[SpillableColumnarBatch] = buildSideTrackerInit + private def cachedUnconditionalInnerHashJoin( + leftKeys: Table, + rightKeys: Table): Option[GatherMapsResult] = { + cudfBuildSide match { + case GpuBuildRight => + withCachedHashJoin(GpuBuildRight) { hashJoin => + 
JoinImpl.innerHashJoinBuildRight(leftKeys, hashJoin) + } + case GpuBuildLeft => + withCachedHashJoin(GpuBuildLeft) { hashJoin => + JoinImpl.innerHashJoinBuildLeft(rightKeys, hashJoin) + } + } + } + + private def computeDistinctUnconditionalJoin( + leftKeys: Table, + rightKeys: Table, + originalJoinType: Option[JoinType]): GatherMapsResult = { + val implName = if (canUseCachedDistinctHashJoin(cudfBuildSide)) { + s"distinct (outer: $joinType, reused)" + } else { + s"distinct (outer: $joinType)" + } + logJoinCardinality(leftKeys, rightKeys, implName, originalJoinType) + + val result = withCachedDistinctHashJoin(cudfBuildSide) { distinctHashJoin => + subJoinType match { + case LeftOuter => + JoinImpl.leftOuterDistinctHashJoinBuildRight(leftKeys, distinctHashJoin) + case RightOuter => + JoinImpl.rightOuterDistinctHashJoinBuildLeft(rightKeys, distinctHashJoin) + case Inner => + if (cudfBuildSide == GpuBuildRight) { + JoinImpl.innerDistinctHashJoinBuildRight(leftKeys, distinctHashJoin) + } else { + JoinImpl.innerDistinctHashJoinBuildLeft(rightKeys, distinctHashJoin) + } + case t => + throw new IllegalStateException(s"unsupported join type: $t") + } + }.getOrElse { + subJoinType match { + case LeftOuter => + val rightRet = leftKeys.leftDistinctJoinGatherMap(rightKeys, compareNullsEqual) + GatherMapsResult.makeFromRight(rightRet) + case RightOuter => + val leftRet = rightKeys.leftDistinctJoinGatherMap(leftKeys, compareNullsEqual) + GatherMapsResult.makeFromLeft(leftRet) + case Inner => + val arrayRet = if (cudfBuildSide == GpuBuildRight) { + leftKeys.innerDistinctJoinGatherMaps(rightKeys, compareNullsEqual) + } else { + rightKeys.innerDistinctJoinGatherMaps(leftKeys, compareNullsEqual).reverse + } + GatherMapsResult(arrayRet(0), arrayRet(1)) + case t => + throw new IllegalStateException(s"unsupported join type: $t") + } + } + logJoinCompletion() + result + } + private def unconditionalJoinGatherMaps( leftKeys: Table, rightKeys: Table): GatherMapsResult = { // Pass the 
original joinType if it was transformed to subJoinType val originalJoinType = if (joinType != subJoinType) Some(joinType) else None + // The distinct outer path only produces a single gather map for LeftOuter/RightOuter, + // but HashJoinStreamSideIterator always needs both maps to update tracking state. + if (buildStats.isDistinct && subJoinType == Inner) { + return computeDistinctUnconditionalJoin(leftKeys, rightKeys, originalJoinType) + } + // Apply heuristics to select the effective strategy for unconditional joins // Note: subJoinType is used for strategy selection since that's what we're actually executing val effectiveStrategy = JoinStrategy.selectStrategy( @@ -1879,8 +2253,10 @@ class HashJoinStreamSideIterator( } logJoinCardinality(leftKeys, rightKeys, implName, originalJoinType) - val innerMaps = JoinImpl.innerHashJoin(leftKeys, rightKeys, compareNullsEqual, - joinOptions.buildSideSelection, cudfBuildSide) + val innerMaps = cachedUnconditionalInnerHashJoin(leftKeys, rightKeys).getOrElse { + JoinImpl.innerHashJoin(leftKeys, rightKeys, compareNullsEqual, + joinOptions.buildSideSelection, cudfBuildSide) + } val leftRowCount = leftKeys.getRowCount val rightRowCount = rightKeys.getRowCount @@ -1918,12 +2294,20 @@ class HashJoinStreamSideIterator( val result = subJoinType match { case LeftOuter => - JoinImpl.leftOuterHashJoinBuildRight(leftKeys, rightKeys, compareNullsEqual) + withCachedHashJoin(GpuBuildRight) { hashJoin => + JoinImpl.leftOuterHashJoinBuildRight(leftKeys, hashJoin) + } + .getOrElse(JoinImpl.leftOuterHashJoinBuildRight(leftKeys, rightKeys, compareNullsEqual)) case RightOuter => - JoinImpl.rightOuterHashJoinBuildLeft(leftKeys, rightKeys, compareNullsEqual) + withCachedHashJoin(GpuBuildLeft) { hashJoin => + JoinImpl.rightOuterHashJoinBuildLeft(rightKeys, hashJoin) + } + .getOrElse(JoinImpl.rightOuterHashJoinBuildLeft(leftKeys, rightKeys, compareNullsEqual)) case Inner => - JoinImpl.innerHashJoin(leftKeys, rightKeys, compareNullsEqual, - 
joinOptions.buildSideSelection, cudfBuildSide) + cachedUnconditionalInnerHashJoin(leftKeys, rightKeys).getOrElse { + JoinImpl.innerHashJoin(leftKeys, rightKeys, compareNullsEqual, + joinOptions.buildSideSelection, cudfBuildSide) + } case t => throw new IllegalStateException(s"unsupported join type: $t") } @@ -2085,7 +2469,8 @@ class HashJoinStreamSideIterator( leftKeys: Table, leftData: LazySpillableColumnarBatch, rightKeys: Table, - rightData: LazySpillableColumnarBatch): Option[JoinGatherer] = { + rightData: LazySpillableColumnarBatch, + numJoinRows: Option[Long]): Option[JoinGatherer] = { NvtxIdWithMetrics(NvtxRegistry.FULL_HASH_JOIN_GATHER_MAP, joinTime) { val maps = lazyCompiledCondition.map { lazyCondition => conditionalJoinGatherMaps(leftKeys, leftData, rightKeys, rightData, lazyCondition) @@ -2242,12 +2627,16 @@ class HashOuterJoinIterator( compareNullsEqual: Boolean, // This is a workaround to how cudf support joins for structs conditionForLogging: Option[Expression], opTime: GpuMetric, - joinTime: GpuMetric) extends Iterator[ColumnarBatch] with TaskAutoCloseableResource { + joinTime: GpuMetric, + enableBuildSideReuse: Boolean = false, + cachedBuildSide: Option[CachedBuildSide] = None) + extends Iterator[ColumnarBatch] with TaskAutoCloseableResource { private val streamJoinIter = new HashJoinStreamSideIterator(joinType, built, boundBuiltKeys, - buildStats, buildSideTrackerInit, stream, boundStreamKeys, streamAttributes, + buildStats, buildSideTrackerInit, stream, boundStreamKeys, streamAttributes, lazyCompiledCondition, - joinOptions, buildSide, compareNullsEqual, conditionForLogging, opTime, joinTime) + joinOptions, buildSide, compareNullsEqual, conditionForLogging, opTime, joinTime, + enableBuildSideReuse, cachedBuildSide) private var finalBatch: Option[ColumnarBatch] = None @@ -2551,10 +2940,24 @@ trait GpuHashJoin extends GpuJoinExec { numOutputRows: GpuMetric, numOutputBatches: GpuMetric, opTime: GpuMetric, - joinTime: GpuMetric): 
Iterator[ColumnarBatch] = { + joinTime: GpuMetric, + enableBuildSideReuse: Boolean, + broadcastBatch: Option[SerializeConcatHostBuffersDeserializeBatch] = None, + buildSideCacheBuilds: GpuMetric = NoopMetric, + buildSideCacheHits: GpuMetric = NoopMetric): Iterator[ColumnarBatch] = { val filterOutNull = GpuHashJoin.buildSideNeedsNullFilter(joinType, compareNullsEqual, buildSide, buildKeys) + val cachedBuildSide = if (enableBuildSideReuse) { + broadcastBatch.map(_.getCachedBuildSide( + boundBuildKeys, + compareNullsEqual, + filterOutNull, + buildSideCacheBuilds, + buildSideCacheHits)) + } else { + None + } val nullFiltered = if (filterOutNull) { val sb = closeOnExcept(builtBatch)( @@ -2597,10 +3000,13 @@ trait GpuHashJoin extends GpuJoinExec { val lazyCond = boundConditionLeftRight.map { cond => LazyCompiledCondition(cond, left.output.size, right.output.size) } - new HashOuterJoinIterator(joinType, spillableBuiltBatch, boundBuildKeys, None, None, + new HashOuterJoinIterator(joinType, spillableBuiltBatch, boundBuildKeys, None, + None, lazyStream, boundStreamKeys, streamedPlan.output, lazyCond, joinOptions, buildSide, - compareNullsEqual, condition, opTime, joinTime) + compareNullsEqual, condition, opTime, joinTime, + enableBuildSideReuse = enableBuildSideReuse, + cachedBuildSide = cachedBuildSide) case _ => if (boundConditionLeftRight.isDefined) { // ConditionalHashJoinIterator will close the LazyCompiledCondition @@ -2612,11 +3018,15 @@ trait GpuHashJoin extends GpuJoinExec { new ConditionalHashJoinIterator(spillableBuiltBatch, boundBuildKeys, None, lazyStream, boundStreamKeys, streamedPlan.output, lazyCond, joinOptions, joinType, buildSide, - compareNullsEqual, condition, opTime, joinTime) + compareNullsEqual, condition, opTime, joinTime, + enableBuildSideReuse = enableBuildSideReuse, + cachedBuildSide = cachedBuildSide) } else { new HashJoinIterator(spillableBuiltBatch, boundBuildKeys, None, lazyStream, boundStreamKeys, streamedPlan.output, joinOptions, - joinType, 
buildSide, compareNullsEqual, condition, opTime, joinTime) + joinType, buildSide, compareNullsEqual, condition, opTime, joinTime, + enableBuildSideReuse = enableBuildSideReuse, + cachedBuildSide = cachedBuildSide) } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubPartitionHashJoin.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubPartitionHashJoin.scala index 12faa151a4b..509ff7fb385 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubPartitionHashJoin.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubPartitionHashJoin.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2025, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2023-2026, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -607,7 +607,7 @@ trait GpuSubPartitionHashJoin extends Logging { self: GpuHashJoin => } // Leverage the original join iterators val joinIter = doJoin(buildCb, streamIter, joinOptions, - numOutputRows, numOutputBatches, opTime, joinTime) + numOutputRows, numOutputBatches, opTime, joinTime, enableBuildSideReuse = false) Some(joinIter) } } diff --git a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala index 395c5b19bfc..1c262ea9960 100644 --- a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala +++ b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala @@ -96,6 +96,8 @@ case class GpuBroadcastHashJoinExec( NUM_INPUT_ROWS -> createMetric(DEBUG_LEVEL, DESCRIPTION_NUM_INPUT_ROWS), NUM_INPUT_BATCHES -> createMetric(DEBUG_LEVEL, DESCRIPTION_NUM_INPUT_BATCHES), 
CONCAT_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_CONCAT_TIME), + BUILD_SIDE_CACHE_BUILDS -> createMetric(MODERATE_LEVEL, DESCRIPTION_BUILD_SIDE_CACHE_BUILDS), + BUILD_SIDE_CACHE_HITS -> createMetric(MODERATE_LEVEL, DESCRIPTION_BUILD_SIDE_CACHE_HITS), ) override def requiredChildDistribution: Seq[Distribution] = { @@ -164,6 +166,7 @@ case class GpuBroadcastHashJoinExec( val targetSize = RapidsConf.GPU_BATCH_SIZE_BYTES.get(conf) val joinOptions = RapidsConf.getJoinOptions(conf, targetSize) + val enableBuildSideReuse = RapidsConf.BROADCAST_HASH_TABLE_REUSE.get(conf) // Get all the broadcast data from the shuffle coalesced into a single partition val partitionSpecs = Seq(CoalescedPartitionSpec(0, shuffleExchange.numPartitions)) @@ -202,12 +205,12 @@ case class GpuBroadcastHashJoinExec( boundStreamKeys) } doJoin(builtBatch, nullFilteredStreamIter, joinOptions, numOutputRows, - numOutputBatches, opTime, joinTime) + numOutputBatches, opTime, joinTime, enableBuildSideReuse) } } else { // builtBatch will be closed in doJoin doJoin(builtBatch, streamIter, joinOptions, numOutputRows, numOutputBatches, opTime, - joinTime) + joinTime, enableBuildSideReuse) } } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala index 70fbf438f1d..20e6c54f7f0 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. + * Copyright (c) 2020-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -19,10 +19,63 @@ package com.nvidia.spark.rapids import com.nvidia.spark.rapids.TestUtils.findOperator import org.apache.spark.SparkConf +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.functions.broadcast import org.apache.spark.sql.rapids.execution.{GpuBroadcastHashJoinExec, GpuHashJoin} class BroadcastHashJoinSuite extends SparkQueryCompareTestSuite { + private def broadcastReuseConf: SparkConf = new SparkConf() + .set("spark.sql.adaptive.enabled", "false") + .set("spark.sql.autoBroadcastJoinThreshold", "-1") + .set("spark.rapids.sql.join.broadcastHashTable.reuse", "true") + .set("spark.rapids.sql.batchSizeBytes", "1") /* AQE off, auto-broadcast off, reuse on */ + /* 128 probe rows: join_key = id % 8, probe_value = id. */ + private def streamedProbeDf(spark: SparkSession): DataFrame = + spark.range(0, 128).selectExpr( + "CAST(id % 8 AS INT) AS join_key", + "CAST(id AS INT) AS probe_value") + /* 8 build rows with unique join_key 0..7; build_value = id * 10. */ + private def distinctBuildDf(spark: SparkSession): DataFrame = + spark.range(0, 8).selectExpr( + "CAST(id AS INT) AS join_key", + "CAST(id * 10 AS INT) AS build_value") + /* 16 build rows: join_key = id % 4, so every key appears 4 times. */ + private def nonDistinctBuildDf(spark: SparkSession): DataFrame = + spark.range(0, 16).selectExpr( + "CAST(id % 4 AS INT) AS join_key", + "CAST(id AS INT) AS build_value") + /* 1024 build rows: join_key = id % 4, so every key appears 256 times. */ + private def largerNonDistinctBuildDf(spark: SparkSession): DataFrame = + spark.range(0, 1024).selectExpr( + "CAST(id % 4 AS INT) AS join_key", + "CAST(id AS INT) AS build_value") + /* 8 probe rows: join_key is NULL, 0..5, then 8 (ELSE branch); probe_value = id. */ + private def nullableProbeDf(spark: SparkSession): DataFrame = + spark.range(0, 8).selectExpr( + "CAST(CASE CAST(id AS INT) " + + "WHEN 0 THEN NULL " + + "WHEN 1 THEN 0 " + + "WHEN 2 THEN 1 " + + "WHEN 3 THEN 2 " + + "WHEN 4 THEN 3 " + + "WHEN 5 THEN 4 " + + "WHEN 6 THEN 5 " + + "ELSE 8 END AS INT) AS join_key", + "CAST(id AS INT) AS probe_value") + /* 8 build rows: join_key is NULL, 0..4, 6, 9 (all distinct); build_value = id * 10. */ + private def nullableDistinctBuildDf(spark: SparkSession): DataFrame = + spark.range(0, 8).selectExpr( + "CAST(CASE CAST(id AS INT) " + + "WHEN 0 THEN NULL " + + "WHEN 1 THEN 0 " + + "WHEN 2 THEN 1 " + 
"WHEN 3 THEN 2 " + + "WHEN 4 THEN 3 " + + "WHEN 5 THEN 4 " + + "WHEN 6 THEN 6 " + + "ELSE 9 END AS INT) AS join_key", + "CAST(id * 10 AS INT) AS build_value") test("broadcast hint isn't propagated after a join") { val conf = new SparkConf() @@ -71,4 +124,123 @@ class BroadcastHashJoinSuite extends SparkQueryCompareTestSuite { } }) } + /* Reuse coverage: each case below joins a probe frame with a broadcast(...) build frame. */ + IGNORE_ORDER_testSparkResultsAreEqual2( + "broadcast hash join reuse distinct inner build right", + streamedProbeDf, + distinctBuildDf, + conf = broadcastReuseConf) { + (probe, build) => probe.join(broadcast(build), Seq("join_key"), "inner") + } + + IGNORE_ORDER_testSparkResultsAreEqual2( + "broadcast hash join reuse distinct left outer build right", + streamedProbeDf, + distinctBuildDf, + conf = broadcastReuseConf) { + (probe, build) => probe.join(broadcast(build), Seq("join_key"), "left") + } + /* Build frame is passed first: the broadcast side is the left join input. */ + IGNORE_ORDER_testSparkResultsAreEqual2( + "broadcast hash join reuse distinct right outer build left", + distinctBuildDf, + streamedProbeDf, + conf = broadcastReuseConf) { + (build, probe) => broadcast(build).join(probe, Seq("join_key"), "right") + } + + IGNORE_ORDER_testSparkResultsAreEqual2( + "broadcast hash join reuse distinct full outer build right", + nullableProbeDf, + nullableDistinctBuildDf, + conf = broadcastReuseConf) { + (probe, build) => probe.join(broadcast(build), Seq("join_key"), "fullouter") + } + + IGNORE_ORDER_testSparkResultsAreEqual2( + "broadcast hash join reuse non-distinct inner build right", + streamedProbeDf, + nonDistinctBuildDf, + conf = broadcastReuseConf) { + (probe, build) => probe.join(broadcast(build), Seq("join_key"), "inner") + } + /* Inner join again, with spark.rapids.sql.join.buildSide=AUTO set on a cloned conf. */ + IGNORE_ORDER_testSparkResultsAreEqual2( + "broadcast hash join reuse non-distinct inner build right with auto build-side selection", + streamedProbeDf, + largerNonDistinctBuildDf, + conf = broadcastReuseConf.clone().set("spark.rapids.sql.join.buildSide", "AUTO")) { + (probe, build) => probe.join(broadcast(build), Seq("join_key"), "inner") + } + + 
IGNORE_ORDER_testSparkResultsAreEqual2( + "broadcast hash join reuse non-distinct inner build left", + nonDistinctBuildDf, + streamedProbeDf, + conf = broadcastReuseConf) { + (build, probe) => broadcast(build).join(probe, Seq("join_key"), "inner") + } + + IGNORE_ORDER_testSparkResultsAreEqual2( + "broadcast hash join reuse non-distinct left semi build right", + streamedProbeDf, + nonDistinctBuildDf, + conf = broadcastReuseConf) { + (probe, build) => probe.join(broadcast(build), Seq("join_key"), "leftsemi") + } + + IGNORE_ORDER_testSparkResultsAreEqual2( + "broadcast hash join reuse non-distinct left anti build right", + streamedProbeDf, + nonDistinctBuildDf, + conf = broadcastReuseConf) { + (probe, build) => probe.join(broadcast(build), Seq("join_key"), "leftanti") + } + + IGNORE_ORDER_testSparkResultsAreEqual2( + "broadcast hash join reuse distinct inner nullable keys build right", + nullableProbeDf, + nullableDistinctBuildDf, + conf = broadcastReuseConf) { + (probe, build) => probe.join(broadcast(build), Seq("join_key"), "inner") + } + + IGNORE_ORDER_testSparkResultsAreEqual2( + "broadcast hash join reuse distinct left outer nullable keys build right", + nullableProbeDf, + nullableDistinctBuildDf, + conf = broadcastReuseConf) { + (probe, build) => probe.join(broadcast(build), Seq("join_key"), "left") + } + /* Equi-join on join_key plus the extra predicate probe_value <= build_value. */ + IGNORE_ORDER_testSparkResultsAreEqual2( + "broadcast hash join reuse conditional inner build right", + streamedProbeDf, + distinctBuildDf, + conf = broadcastReuseConf) { + (probe, build) => + probe.join(broadcast(build), + probe("join_key") === build("join_key") && + probe("probe_value") <= build("build_value"), "inner") + } + /* Joins one broadcast frame twice: expects 2 GPU BHJ nodes and a ReusedExchangeExec. */ + test("broadcast hash join reuse same broadcast in multiple joins plan") { + val conf = broadcastReuseConf.clone().set("spark.sql.exchange.reuse", "true") + withGpuSparkSession(spark => { + val probe = streamedProbeDf(spark) + val build = broadcast(distinctBuildDf(spark)) + val joined = probe + .join(build, Seq("join_key"), "inner") + 
.select("join_key", "probe_value") + .join(build, Seq("join_key"), "inner") + .select("join_key", "probe_value") + /* Execute first, then inspect the executed plan. */ + joined.collect() + val plan = joined.queryExecution.executedPlan + val bhjCount = PlanUtils.findOperators(plan, _.isInstanceOf[GpuBroadcastHashJoinExec]) + val reusedExchanges = PlanUtils.findOperators(plan, _.isInstanceOf[ReusedExchangeExec]) + assertResult(2)(bhjCount.size) + assert(reusedExchanges.nonEmpty) + }, conf) + } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/JoinsSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/JoinsSuite.scala index 2f69cf6ca61..383cbf86b7c 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/JoinsSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/JoinsSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2025, NVIDIA CORPORATION. + * Copyright (c) 2019-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -17,6 +17,7 @@ package com.nvidia.spark.rapids import org.apache.spark.SparkConf +import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.ExistenceJoin import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, Join, JoinHint} @@ -45,6 +46,39 @@ class JoinsSuite extends SparkQueryCompareTestSuite { .set("spark.sql.join.preferSortMergeJoin", "false") .set("spark.sql.shuffle.partitions", "2") // hack to try and work around bug in cudf + private def shuffledDistinctOuterConf: SparkConf = new SparkConf() + .set("spark.sql.adaptive.enabled", "false") + .set("spark.sql.autoBroadcastJoinThreshold", "-1") + .set("spark.sql.join.preferSortMergeJoin", "false") + .set("spark.sql.shuffle.partitions", "2") + .set("spark.rapids.sql.batchSizeBytes", "1") /* AQE off, no auto-broadcast, 2 partitions */ + /* 8 rows: join_key is NULL, 0..4, 6, 8 (all distinct); left_value = id. */ + private def distinctNullableLeftDf(spark: SparkSession): DataFrame = + spark.range(0, 8).selectExpr( + "CAST(CASE CAST(id AS INT) " + + "WHEN 0 THEN NULL " + + "WHEN 1 THEN 0 " + + "WHEN 2 THEN 1 " + + "WHEN 3 THEN 2 " + + "WHEN 4 THEN 3 " + + "WHEN 5 THEN 4 " + + "WHEN 6 THEN 6 " + + "ELSE 8 END AS INT) AS join_key", + "CAST(id AS INT) AS left_value") + /* 8 rows: join_key is NULL, 0, 1, 2, 4, 5, 7, 9 (all distinct); right_value = id * 10. */ + private def distinctNullableRightDf(spark: SparkSession): DataFrame = + spark.range(0, 8).selectExpr( + "CAST(CASE CAST(id AS INT) " + + "WHEN 0 THEN NULL " + + "WHEN 1 THEN 0 " + + "WHEN 2 THEN 1 " + + "WHEN 3 THEN 2 " + + "WHEN 4 THEN 4 " + + "WHEN 5 THEN 5 " + + "WHEN 6 THEN 7 " + + "ELSE 9 END AS INT) AS join_key", + "CAST(id * 10 AS INT) AS right_value") + IGNORE_ORDER_testSparkResultsAreEqual2("Test hash join", longsDf, biggerLongsDf, conf = shuffledJoinConf) { (A, B) => A.join(B, A("longs") === B("longs")) @@ -70,6 +104,25 @@ class JoinsSuite extends SparkQueryCompareTestSuite { (A, B) => A.join(B, A("longs") === B("longs"), "FullOuter") } + IGNORE_ORDER_testSparkResultsAreEqual2( + "Test shuffled distinct full join with nullable keys", + 
distinctNullableLeftDf, + distinctNullableRightDf, + conf = shuffledDistinctOuterConf) { + (left, right) => left.repartition(2).join(right.repartition(2), Seq("join_key"), "FullOuter") + } + /* GPU-only run of the same join: the plan must contain a GpuShuffledSizedHashJoinExec. */ + test("Test shuffled distinct full join with nullable keys uses sized hash join") { + withGpuSparkSession(spark => { + val left = distinctNullableLeftDf(spark).repartition(2) + val right = distinctNullableRightDf(spark).repartition(2) + val joined = left.join(right, Seq("join_key"), "FullOuter") + joined.collect() + assert(TestUtils.findOperator(joined.queryExecution.executedPlan, + _.isInstanceOf[GpuShuffledSizedHashJoinExec[_]]).isDefined) + }, shuffledDistinctOuterConf) + } + IGNORE_ORDER_testSparkResultsAreEqual2("Test cross join", longsDf, biggerLongsDf, conf = shuffledJoinConf) { (A, B) => A.join(B.hint("broadcast"), A("longs") < B("longs"), "Cross") diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/spill/SharedRecomputableDeviceHandleSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/spill/SharedRecomputableDeviceHandleSuite.scala new file mode 100644 index 00000000000..3ac82a92c19 --- /dev/null +++ b/tests/src/test/scala/com/nvidia/spark/rapids/spill/SharedRecomputableDeviceHandleSuite.scala @@ -0,0 +1,161 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.spill + +import java.util.concurrent.{Callable, CountDownLatch, Executors, TimeUnit} + +import scala.collection.mutable.ArrayBuffer + +import com.nvidia.spark.rapids.Arm.withResource +/* Tests SharedRecomputableDeviceHandle: lease pinning, spill, and rebuild on acquire. */ +class SharedRecomputableDeviceHandleSuite extends SpillUnitTestBase { + private class TestResource(val id: Int, closedIds: ArrayBuffer[Int]) extends AutoCloseable { + private var closed = false + + def isClosed: Boolean = closed + /* Appends id to closedIds on first close; throws IllegalStateException on a second. */ + override def close(): Unit = { + if (closed) { + throw new IllegalStateException(s"resource $id closed twice") + } + closed = true + closedIds += id + } + } + /* A leased handle must not spill; once released it spills and the next acquire rebuilds. */ + test("recomputable handle uses pin count for spillability and rebuilds after spill") { + val closedIds = ArrayBuffer[Int]() + var buildCount = 0 + + def buildResource(): TestResource = { + buildCount += 1 + new TestResource(buildCount, closedIds) + } + + withResource( + new SharedRecomputableDeviceHandle(1024L, buildResource(), () => buildResource())) { + handle => + SpillFramework.stores.deviceStore.track(handle) + assertResult(1)(SpillFramework.stores.deviceStore.numHandles) + assert(handle.spillable) + /* While leased: not spillable, and a device-store spill request frees 0 bytes. */ + withResource(handle.acquire()) { lease => + assertResult(1)(lease.resource.id) + assert(!handle.spillable) + assertResult(0L)(SpillFramework.stores.deviceStore.spill(handle.approxSizeInBytes)) + } + /* Lease released: the spill returns approxSizeInBytes and resource 1 is closed. */ + assert(handle.spillable) + assertResult(handle.approxSizeInBytes)( + SpillFramework.stores.deviceStore.spill(handle.approxSizeInBytes)) + assertResult(Seq(1))(closedIds.toSeq) + assertResult(0)(SpillFramework.stores.deviceStore.numHandles) + /* Acquiring after the spill yields a freshly built resource with id 2. */ + withResource(handle.acquire()) { lease => + assertResult(2)(lease.resource.id) + assert(!lease.resource.isClosed) + } + /* The device store reports one handle again after the rebuild. */ + assertResult(2)(buildCount) + assertResult(1)(SpillFramework.stores.deviceStore.numHandles) + } + + assertResult(Seq(1, 2))(closedIds.sorted.toSeq) + } + /* Spills twice around a rebuild, calls releaseSpilled(), then expects ids 1 and 2 closed. */ + test("releaseSpilled only closes evicted generations") { + val closedIds = ArrayBuffer[Int]() + var buildCount = 0 + + def buildResource(): 
TestResource = { + buildCount += 1 + new TestResource(buildCount, closedIds) + } + + withResource( + new SharedRecomputableDeviceHandle(1024L, buildResource(), () => buildResource())) { + handle => + SpillFramework.stores.deviceStore.track(handle) + assertResult(handle.approxSizeInBytes)(handle.spill()) + + withResource(handle.acquire()) { lease => + assertResult(2)(lease.resource.id) + assert(!lease.resource.isClosed) + } + + assertResult(handle.approxSizeInBytes)(handle.spill()) + handle.releaseSpilled() + assertResult(Seq(1, 2))(closedIds.toSeq.sorted) + /* A further acquire must build a third resource that is still open. */ + withResource(handle.acquire()) { lease => + assertResult(3)(lease.resource.id) + assert(!lease.resource.isClosed) + } + } + + assertResult(Seq(1, 2, 3))(closedIds.sorted.toSeq) + } + /* Two threads acquire after a spill; the rebuild is gated by latches and must run once. */ + test("concurrent acquires only rebuild once after eviction") { + val closedIds = ArrayBuffer[Int]() + var buildCount = 0 + val rebuildStarted = new CountDownLatch(1) + val allowRebuild = new CountDownLatch(1) + /* Builds with id > 1 signal rebuildStarted, then wait up to 30s on allowRebuild. */ + def buildResource(): TestResource = synchronized { + buildCount += 1 + val id = buildCount + if (id > 1) { + rebuildStarted.countDown() + assert(allowRebuild.await(30, TimeUnit.SECONDS)) + } + new TestResource(id, closedIds) + } + + withResource( + new SharedRecomputableDeviceHandle(1024L, buildResource(), () => buildResource())) { + handle => + SpillFramework.stores.deviceStore.track(handle) + assertResult(handle.approxSizeInBytes)(handle.spill()) + /* Two pool threads each acquire the spilled handle and report the resource id. */ + val pool = Executors.newFixedThreadPool(2) + try { + val futures = (0 until 2).map { _ => + pool.submit(new Callable[Int] { + override def call(): Int = { + withResource(handle.acquire()) { lease => + lease.resource.id + } + } + }) + } + /* Release the gate only once a rebuild has been observed to start. */ + assert(rebuildStarted.await(30, TimeUnit.SECONDS)) + allowRebuild.countDown() + /* Both leases must see id 2, and buildCount must still be 2. */ + val ids = futures.map(_.get(30, TimeUnit.SECONDS)).sorted + assertResult(Seq(2, 2))(ids) + assertResult(2)(buildCount) + } finally { + pool.shutdownNow() + assert(pool.awaitTermination(30, TimeUnit.SECONDS)) + } + } + + assertResult(Seq(1, 
2))(closedIds.sorted.toSeq) + } +}