1 change: 1 addition & 0 deletions integration_tests/src/main/python/asserts.py
@@ -528,6 +528,7 @@ def assert_cpu_and_gpu_are_equal_collect_with_capture(func,
         _sort_locally(from_cpu, from_gpu)
 
     assert_equal(from_cpu, from_gpu)
+    return gpu_df
 
 def assert_cpu_and_gpu_are_equal_sql_with_capture(df_fun,
         sql,
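The capture helper in asserts.py now returns the captured GPU DataFrame instead of discarding it, so callers can hand the executed plan to JVM-side helpers. A minimal sketch of the new calling pattern; the query, exec class, and conf below are illustrative placeholders, not part of this patch:

```python
from asserts import assert_cpu_and_gpu_are_equal_collect_with_capture

# Results are still compared CPU vs GPU as before; the return value is new.
gpu_df = assert_cpu_and_gpu_are_equal_collect_with_capture(
    lambda spark: spark.range(100).selectExpr("id % 4 AS k").groupBy("k").count(),
    exist_classes='GpuHashAggregateExec',
    conf={})
# gpu_df._jdf can now be passed to JVM-side plan utilities such as
# ExecutionPlanCaptureCallback.sumMetric (see the new tests in join_test.py below).
```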
78 changes: 77 additions & 1 deletion integration_tests/src/main/python/join_test.py
@@ -19,7 +19,7 @@
 from asserts import (assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_row_counts_equal,
                      assert_gpu_fallback_collect, assert_cpu_and_gpu_are_equal_collect_with_capture,
                      assert_cpu_and_gpu_are_equal_sql_with_capture, assert_gpu_and_cpu_are_equal_sql)
-from conftest import is_emr_runtime
+from conftest import is_emr_runtime, spark_jvm
 from data_gen import *
 from marks import ignore_order, allow_non_gpu, incompat, validate_execs_in_gpu_plan, disable_ansi_mode
 from spark_session import with_cpu_session, is_before_spark_330, is_databricks_runtime, is_spark_400_or_later, is_spark_411_or_later
@@ -217,6 +217,82 @@ def do_join(spark):
     })
 
 
+@ignore_order(local=True)
+@pytest.mark.parametrize("kudo_enabled", ["true", "false"], ids=idfn)
+def test_broadcast_hash_join_reuse_large_build_aqe(kudo_enabled):
+    def do_join(spark):
+        left = spark.range(4096, numPartitions=4).selectExpr(
+            "CAST(id % 64 AS INT) AS join_key",
+            "CAST(id AS INT) AS probe_value")
+        right = spark.range(8192).selectExpr(
+            "CAST(id % 64 AS INT) AS join_key",
+            "CAST(id * 3 AS INT) AS build_value")
+        return left.join(broadcast(right), "join_key", "inner").groupBy("join_key").count()
+
+    conf = {
+        'spark.sql.adaptive.enabled': 'true',
+        'spark.sql.autoBroadcastJoinThreshold': '-1',
+        'spark.sql.shuffle.partitions': '4',
+        'spark.rapids.sql.batchSizeBytes': '1024',
+        'spark.rapids.sql.join.broadcastHashTable.reuse': 'true',
+        kudo_enabled_conf_key: kudo_enabled
+    }
+    gpu_df = assert_cpu_and_gpu_are_equal_collect_with_capture(
+        do_join,
+        exist_classes='GpuBroadcastHashJoinExec',
+        conf=conf)
+    jvm = spark_jvm()
+    cache_builds = jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.sumMetric(
+        gpu_df._jdf,
+        'GpuBroadcastHashJoinExec',
+        'buildSideCacheBuilds')
+    cache_hits = jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.sumMetric(
+        gpu_df._jdf,
+        'GpuBroadcastHashJoinExec',
+        'buildSideCacheHits')
+    assert cache_builds == 1
+    assert cache_hits > 0


+@ignore_order(local=True)
+@pytest.mark.parametrize("kudo_enabled", ["true", "false"], ids=idfn)
+def test_broadcast_hash_join_reuse_same_broadcast_multi_join_aqe(kudo_enabled):
+    def do_join(spark):
+        fact = spark.range(512, numPartitions=4).selectExpr(
+            "CAST(id % 8 AS INT) AS join_key",
+            "CAST(id AS INT) AS probe_value")
+        dim = broadcast(spark.range(32).selectExpr(
+            "CAST(id % 8 AS INT) AS join_key",
+            "CAST(id AS INT) AS build_value"))
+        return fact.join(dim, "join_key", "inner").select("join_key", "probe_value") \
+            .join(dim, "join_key", "inner").groupBy("join_key").count()
+
+    conf = {
+        'spark.sql.adaptive.enabled': 'true',
+        'spark.sql.autoBroadcastJoinThreshold': '-1',
+        'spark.sql.exchange.reuse': 'true',
+        'spark.sql.shuffle.partitions': '4',
+        'spark.rapids.sql.batchSizeBytes': '1024',
+        'spark.rapids.sql.join.broadcastHashTable.reuse': 'true',
+        kudo_enabled_conf_key: kudo_enabled
+    }
+    gpu_df = assert_cpu_and_gpu_are_equal_collect_with_capture(
+        do_join,
+        exist_classes='GpuBroadcastHashJoinExec,ReusedExchangeExec',
+        conf=conf)
+    jvm = spark_jvm()
+    cache_builds = jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.sumMetric(
+        gpu_df._jdf,
+        'GpuBroadcastHashJoinExec',
+        'buildSideCacheBuilds')
+    cache_hits = jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.sumMetric(
+        gpu_df._jdf,
+        'GpuBroadcastHashJoinExec',
+        'buildSideCacheHits')
+    assert cache_builds == 1
+    assert cache_hits > 0


 # local sort because of https://github.com/NVIDIA/spark-rapids/issues/84
 # After 3.1.0 is the min spark version we can drop this
 @ignore_order(local=True)
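Both new tests assert `cache_builds == 1` and `cache_hits > 0`: with `spark.rapids.sql.batchSizeBytes` forced down to 1024, one broadcast feeds many stream batches, so the hash table should be built exactly once and hit the cache on every later batch. In the multi-join test, the `ReusedExchangeExec` requirement is what makes the single build meaningful across two joins. A hedged way to eyeball the reused exchange from the returned DataFrame (purely illustrative; the test already asserts it via `exist_classes`):

```python
# The capture assert has already executed the query, so the adaptive plan
# is final here and should contain the reused broadcast exchange.
plan = gpu_df._jdf.queryExecution().executedPlan().toString()
assert 'ReusedExchange' in plan
```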
@@ -116,6 +116,8 @@ object GpuMetric extends Logging {
   val BUILD_DATA_SIZE = "buildDataSize"
   val BUILD_TIME = "buildTime"
   val STREAM_TIME = "streamTime"
+  val BUILD_SIDE_CACHE_BUILDS = "buildSideCacheBuilds"
+  val BUILD_SIDE_CACHE_HITS = "buildSideCacheHits"
   val NUM_TASKS_FALL_BACKED = "numTasksFallBacked"
   val NUM_TASKS_REPARTITIONED = "numTasksRepartitioned"
   val NUM_TASKS_SKIPPED_AGG = "numTasksSkippedAgg"
@@ -173,6 +175,8 @@ object GpuMetric extends Logging {
   val DESCRIPTION_BUILD_DATA_SIZE = "build side size"
   val DESCRIPTION_BUILD_TIME = "build time"
   val DESCRIPTION_STREAM_TIME = "stream time"
+  val DESCRIPTION_BUILD_SIDE_CACHE_BUILDS = "cached build side builds"
+  val DESCRIPTION_BUILD_SIDE_CACHE_HITS = "cached build side hits"
   val DESCRIPTION_NUM_TASKS_FALL_BACKED = "number of sort fallback tasks"
   val DESCRIPTION_NUM_TASKS_REPARTITIONED = "number of tasks repartitioned for agg"
   val DESCRIPTION_NUM_TASKS_SKIPPED_AGG = "number of tasks skipped aggregation"
@@ -264,7 +264,7 @@ case class GpuShuffledHashJoinExec(
           }
           // doJoin will close singleBatch
           doJoin(singleBatch, maybeBufferedStreamIter, joinOptions,
-            numOutputRows, numOutputBatches, opTime, joinTime)
+            numOutputRows, numOutputBatches, opTime, joinTime, enableBuildSideReuse = false)
         case Right(builtBatchIter) =>
           // For big joins, when the build data can not fit into a single batch.
           val sizeBuildIter = builtBatchIter.map { cb =>
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2025, NVIDIA CORPORATION.
+ * Copyright (c) 2025-2026, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -393,6 +393,9 @@ object NvtxRegistry {
   val BUILD_JOIN_TABLE: NvtxId = NvtxId("build join table", NvtxColor.GREEN,
     "Building hash table for join operation")
 
+  val BROADCAST_HASH_TABLE_BUILD: NvtxId = NvtxId("broadcast hash table build",
+    NvtxColor.GREEN, "Building cuDF hash table for broadcast hash join")
+
   // Window operations
   val WINDOW: NvtxId = NvtxId("window", NvtxColor.CYAN,
     "Computing window function results")
@@ -780,6 +783,7 @@ object NvtxRegistry {
     register(EXISTENCE_JOIN_SCATTER_MAP)
     register(EXISTENCE_JOIN_BATCH)
     register(BUILD_JOIN_TABLE)
+    register(BROADCAST_HASH_TABLE_BUILD)
     register(WINDOW)
     register(RUNNING_WINDOW)
     register(DOUBLE_BATCHED_WINDOW_PRE)
13 changes: 13 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
@@ -782,6 +782,17 @@ val GPU_COREDUMP_PIPE_PATTERN = conf("spark.rapids.gpu.coreDump.pipePattern")
     .checkValues(JoinBuildSideSelection.values.map(_.toString))
     .createWithDefault(JoinBuildSideSelection.AUTO.toString)
 
+  val BROADCAST_HASH_TABLE_REUSE =
+    conf("spark.rapids.sql.join.broadcastHashTable.reuse")
+      .doc("Enable reuse of the broadcast-side hash table for broadcast hash joins. " +
+        "When enabled, the hash table is built once per broadcast and shared across all " +
+        "stream batches within a task and across all tasks that consume the same broadcast " +
+        "on an executor. Reuse pins the physical build side to the broadcast side for the " +
+        "lifetime of each cached join, overriding the dynamic build-side selection " +
+        s"heuristic configured by ${JOIN_BUILD_SIDE.key}.")
+      .booleanConf
+      .createWithDefault(false)
+
   val LOG_JOIN_CARDINALITY = conf("spark.rapids.sql.join.logCardinality")
     .doc("Enable logging of join cardinality statistics to help diagnose performance issues. " +
       "When enabled, logs task context, key data types, join condition, row counts, and " +
@@ -3280,6 +3291,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging {
 
   lazy val joinGathererSizeEstimateThreshold: Double = get(JOIN_GATHERER_SIZE_ESTIMATE_THRESHOLD)
 
+  lazy val broadcastHashTableReuse: Boolean = get(BROADCAST_HASH_TABLE_REUSE)
+
   /**
    * Get join options based on the current configuration.
    * @param targetSize the target batch size in bytes to use for the join
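Outside the test harness, the new knob is an ordinary boolean conf that defaults to false, so build-side selection behavior is unchanged unless a user opts in. A minimal sketch of enabling it on a plugin-equipped session; only the conf key comes from this patch, the rest is generic spark-rapids setup:

```python
from pyspark.sql import SparkSession

# Generic RAPIDS Accelerator session setup with the new reuse knob enabled.
spark = (SparkSession.builder
         .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
         .config("spark.rapids.sql.join.broadcastHashTable.reuse", "true")
         .getOrCreate())
```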