From 3e0ba02a46529e3141aff84a356609cacfe9e326 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Wed, 29 Apr 2026 09:05:29 +0800 Subject: [PATCH 1/5] ProjectExec split and retry Signed-off-by: Haoyang Li --- .../python/gpu_project_split_retry_test.py | 69 +++++++++ .../python/legacy_parser_oom_repro_test.py | 68 +++++++++ .../com/nvidia/spark/rapids/RapidsConf.scala | 13 ++ .../spark/rapids/basicPhysicalOperators.scala | 131 ++++++++++++++++-- 4 files changed, 271 insertions(+), 10 deletions(-) create mode 100644 integration_tests/src/main/python/gpu_project_split_retry_test.py create mode 100644 integration_tests/src/main/python/legacy_parser_oom_repro_test.py diff --git a/integration_tests/src/main/python/gpu_project_split_retry_test.py b/integration_tests/src/main/python/gpu_project_split_retry_test.py new file mode 100644 index 00000000000..9e4f1eb7932 --- /dev/null +++ b/integration_tests/src/main/python/gpu_project_split_retry_test.py @@ -0,0 +1,69 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. +# +# Tests for GpuProjectExec split-retry (Phase 1 of zero-config batching). +# +# `spark.rapids.sql.test.injectRetryOOM` lets us deterministically inject a +# GPU SplitAndRetryOOM mid-projection. With split-retry enabled (default), the +# new path in GpuProjectExec.projectWithSplitRetry must recover by halving the +# input batch by rows and re-running the projection on each half. + +import pyspark.sql.functions as f + +from spark_session import with_gpu_session + + +# Note: we deliberately do NOT have a "@inject_oom split=true" test here. +# `forceSplitAndRetryOOM` is re-fired at every retry-iterator instance the +# task creates, and GpuRangeExec uses its own withRetry(reduceRowsNumberByHalf) +# loop — under split=true that loop divides row count down to its lower +# bound and aborts before ever reaching GpuProjectExec. 
Real +# split-retry-on-OOM coverage for GpuProjectExec lives in the +# legacy_parser_oom_repro test (which uses real cuDF-scratch pressure +# rather than synthetic injection). + + +def test_project_split_retry_handles_plain_retry_oom(): + """Inject a plain GpuRetryOOM (not SplitAndRetry). The retry framework + should resolve this on its own without invoking the splitter; the new + path must not get in the way.""" + def run(spark): + return (spark.range(0, 10_000, numPartitions=1) + .selectExpr("id + 1 as a", "cast(id as string) as b") + .collect()) + result = with_gpu_session(run, conf={ + "spark.rapids.sql.test.injectRetryOOM": "num_ooms=1,type=GPU,split=false", + "spark.rapids.sql.projectExec.splitRetry.enabled": "true", + }) + assert len(result) == 10_000 + + +def test_project_split_retry_disabled_falls_back_to_legacy_path(): + """When split-retry is disabled by conf, the legacy withRetryNoSplit path + is used. Inject a non-split retry OOM so the legacy path can resolve it, + confirming the conf knob actually wires through.""" + def run(spark): + return (spark.range(0, 10_000, numPartitions=1) + .selectExpr("id + 1 as a") + .collect()) + result = with_gpu_session(run, conf={ + "spark.rapids.sql.test.injectRetryOOM": "num_ooms=1,type=GPU,split=false", + "spark.rapids.sql.projectExec.splitRetry.enabled": "false", + }) + assert len(result) == 10_000 + + +def test_project_with_nondeterministic_runs_normally(): + """Mixed deterministic + non-deterministic projection must take the + legacy withRetryNoSplit path even with split-retry enabled (because the + row-stitching logic requires alignment that row-splitting would break). 
+ This sanity test confirms the dispatch logic does not break ordinary + execution; no OOM injection so the legacy path is exercised in its + happy path.""" + def run(spark): + return (spark.range(0, 10_000, numPartitions=1) + .selectExpr("id + 1 as a", "rand() as r") + .collect()) + result = with_gpu_session(run, conf={ + "spark.rapids.sql.projectExec.splitRetry.enabled": "true", + }) + assert len(result) == 10_000 diff --git a/integration_tests/src/main/python/legacy_parser_oom_repro_test.py b/integration_tests/src/main/python/legacy_parser_oom_repro_test.py new file mode 100644 index 00000000000..627d22f0865 --- /dev/null +++ b/integration_tests/src/main/python/legacy_parser_oom_repro_test.py @@ -0,0 +1,68 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. +# +# End-to-end reproducer for the customer GPU OOM in the legacy timestamp +# parser. The interesting GPU work is the projection of +# `from_unixtime(... + unix_timestamp(ts_str, fmt))` under +# `spark.sql.legacy.timeParserPolicy=LEGACY`, which routes to +# GpuToTimestamp.parseStringAsTimestampWithLegacyParserPolicy and into cuDF +# stringReplaceWithBackrefs / matches_re. cuDF allocates a per-row scratch +# buffer (~600 B/row in our measurement) that the spark-rapids estimator +# cannot see, so before split-retry the projection would OOM the pool. +# +# With Phase 1 (GpuProjectExec.projectWithSplitRetry, default-on), the OOM is +# caught at the project layer; the input batch is halved by rows and the +# projection is re-run on each half until each piece fits. 
+# +# To actually reproduce the OOM-pressure path, the test must be run with: +# - small RMM pool (PYSP_TEST_spark_rapids_memory_gpu_allocSize=8g) +# - reader emitting a single large batch (large parquet row group + +# spark.rapids.sql.reader.chunked=false + +# spark.rapids.sql.reader.batchSizeBytes=3g) +# - TEST_PARALLEL=1 so the pool isn't sliced across xdist workers +# Without those, the GPU reader chunks the input and no single batch is large +# enough to trigger the cuDF-scratch OOM. + +import pytest +import pyspark.sql.functions as f + +from spark_session import with_cpu_session, with_gpu_session + + +@pytest.mark.parametrize("rows", [100_000_000, 200_000_000], ids=["100M", "200M"]) +def test_legacy_parser_oom_repro(spark_tmp_path, rows): + data_path = spark_tmp_path + '/LEGACY_PARSER_OOM' + + # Force a single huge parquet row group so the GPU reader sees a single + # big batch when chunked-reader is disabled at run time. + write_conf = { + 'spark.sql.parquet.block.size': str(4 * 1024 * 1024 * 1024), + 'parquet.block.size': str(4 * 1024 * 1024 * 1024), + } + # The CASE keeps Catalyst from folding ts_str into table metadata; in + # practice every row holds the same valid timestamp string. + with_cpu_session(lambda spark: spark.range(0, rows, numPartitions=1) + .selectExpr( + "id as offset_long", + "case when id >= 0 then '2024-06-15 12:34:56' " + "else '2024-06-15 12:34:55' end as ts_str") + .write.mode('overwrite').parquet(data_path), conf=write_conf) + + def run(spark): + # Aggregate to a single scalar so the projected `t` column is fully + # materialized on the GPU but we don't ship a billion rows back. 
+ spark.read.parquet(data_path).selectExpr( + "from_unixtime(offset_long + " + "unix_timestamp(ts_str, 'yyyy-MM-dd HH:mm:ss')) as t" + ).agg(f.sum(f.length('t'))).collect() + + conf = { + "spark.sql.legacy.timeParserPolicy": "LEGACY", + # Required to let UnixTimestamp/FromUnixTime with LEGACY format reach + # the GPU path (parseStringAsTimestampWithLegacyParserPolicy). This + # is the knob the customer has turned on in production. + "spark.rapids.sql.incompatibleDateFormats.enabled": "true", + # Phase 1 split-retry is default-on; left explicit here to make the + # test's expectation visible and to allow flipping it for debugging. + "spark.rapids.sql.projectExec.splitRetry.enabled": "true", + } + with_gpu_session(run, conf=conf) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala index 5cfc9cb09bf..97f31cd2c5f 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -2737,6 +2737,17 @@ val SHUFFLE_COMPRESSION_LZ4_CHUNK_SIZE = conf("spark.rapids.shuffle.compression. .longConf .createOptional + val PROJECT_SPLIT_RETRY_ENABLED = conf("spark.rapids.sql.projectExec.splitRetry.enabled") + .doc("When true, GpuProjectExec uses split-and-retry on GPU OOM for purely " + + "deterministic projections: the input batch is halved by rows and the " + + "projection is re-run on each half. Mixed deterministic + non-deterministic " + + "projections fall back to the existing withRetryNoSplit path because the " + + "non-deterministic side is computed once on the full batch and stitched " + + "row-by-row to the deterministic side, which row-splitting would break. 
" + + "Disable this to revert to the prior behavior.") + .booleanConf + .createWithDefault(true) + val TEST_IO_ENCRYPTION = conf("spark.rapids.test.io.encryption") .doc("Only for tests: verify for IO encryption") .internal() @@ -3934,6 +3945,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging { lazy val chunkedPackBounceBufferSize: Long = get(CHUNKED_PACK_BOUNCE_BUFFER_SIZE) + lazy val isProjectSplitRetryEnabled: Boolean = get(PROJECT_SPLIT_RETRY_ENABLED) + lazy val chunkedPackBounceBufferCount: Int = get(CHUNKED_PACK_BOUNCE_BUFFER_COUNT) lazy val spillToDiskBounceBufferSize: Long = get(SPILL_TO_DISK_BOUNCE_BUFFER_SIZE) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala index a97f830fe3e..cc37911d3a2 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection, RangePartitioning, SinglePartition, UnknownPartitioning} import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SampleExec, SparkPlan} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.rapids.{GpuCreateArray, GpuCreateMap, GpuCreateNamedStruct, GpuPartitionwiseSampledRDD, GpuPoissonSampler} import org.apache.spark.sql.rapids.execution.TrampolineUtil import org.apache.spark.sql.types._ @@ -162,6 +163,21 @@ object GpuProjectExec { def projectWithRetrySingleBatch(sb: SpillableColumnarBatch, boundExprs: Seq[Expression]): ColumnarBatch = { + // For purely deterministic projections, use split-retry: on GPU OOM, halve + // the input batch by rows and re-run on each half. 
This recovers from + // cuDF-internal scratch allocations that the pre-split estimator cannot + // see (e.g. regex / string-replace working memory). + // + // Mixed deterministic + non-deterministic projections fall through to the + // existing withRetryNoSplit path: the non-deterministic side is computed + // once on the full input batch and stitched row-by-row to the deterministic + // side, and row-splitting either side would break that alignment. + if (new RapidsConf(SQLConf.get).isProjectSplitRetryEnabled && + boundExprs.forall(_.deterministic)) { + val retryables = GpuExpressionsUtils.collectRetryables(boundExprs) + return runWithSplitRetry(sb, retryables, project(_, boundExprs)) + } + // First off we want to find/run all of the expressions that are not retryable, // These cannot be retried. val (retryableExprs, notRetryableExprs) = boundExprs.partition( @@ -212,6 +228,80 @@ object GpuProjectExec { } } } + + /** + * Run a deterministic projection with row-split retry. On GPU OOM the retry + * framework calls splitSpillableInHalfByRows to halve the input batch and + * re-runs the projection on each half. The resulting sub-batches are + * concatenated back into a single output batch to preserve the single-batch + * contract of projectAndCloseWithRetrySingleBatch. + * + * Caller must ensure the projection driven by `runProject` is purely + * deterministic — non-deterministic expressions cannot be safely + * re-evaluated on row-split sub-batches. + * + * `runProject` receives a (non-spillable) ColumnarBatch and returns the + * projected ColumnarBatch. It must not close its input (the framework will). + * + * Takes ownership of `sb`: it is closed by the retry iterator when drained. + * If the caller does not want to surrender ownership, it must increment the + * ref count before calling. 
+ */ + private[rapids] def runWithSplitRetry( + sb: SpillableColumnarBatch, + retryables: Seq[Retryable], + runProject: ColumnarBatch => ColumnarBatch): ColumnarBatch = { + retryables.foreach(_.checkpoint()) + val resultIter = withRetry(sb, splitSpillableInHalfByRows) { spillable => + withResource(spillable.getColumnarBatch()) { cb => + withRestoreOnRetry(retryables) { + runProject(cb) + } + } + } + // Drain the retry iterator. Each piece is an independently-projected + // sub-batch. If draining itself throws (e.g. a later split also OOMs and + // retry is exhausted), close any pieces collected so far before + // propagating. + val pieces = ArrayBuffer[ColumnarBatch]() + closeOnExcept(pieces) { _ => + while (resultIter.hasNext) { + pieces += resultIter.next() + } + } + if (pieces.length == 1) { + pieces.head + } else { + concatColumnarBatches(pieces.toArray) + } + } + + /** + * Concatenate a non-empty array of ColumnarBatches into a single ColumnarBatch. + * Closes all input batches on success; on failure, the input batches are also + * closed via closeOnExcept. The returned batch's device buffers are + * independent of the inputs (Table.concatenate copies). + * + * Note: if the concatenation itself OOMs, it will be caught by whatever outer + * retry layer surrounds the caller. We don't try to recover here because by + * definition the retry framework already split the input as far as it could. 
+ */ + private def concatColumnarBatches(pieces: Array[ColumnarBatch]): ColumnarBatch = { + require(pieces.nonEmpty, "concatColumnarBatches requires at least one piece") + closeOnExcept(pieces) { _ => + val outputTypes: Array[DataType] = (0 until pieces.head.numCols()).map { i => + pieces.head.column(i).asInstanceOf[GpuColumnVector].dataType() + }.toArray + val result = withResource(pieces.safeMap(GpuColumnVector.from)) { tables => + withResource(Table.concatenate(tables: _*)) { concatenated => + GpuColumnVector.from(concatenated, outputTypes) + } + } + // Result holds independent device buffers; release the input pieces. + pieces.foreach(_.close()) + result + } + } } /** @@ -947,6 +1037,17 @@ case class GpuProjectAstExec( } } + /** + * Are all expressions across all tiers deterministic. This is a stricter + * check than [[areAllRetryable]] — a Retryable but non-deterministic + * expression (e.g. GpuRand) is retryable but cannot be safely re-evaluated + * on a row-split sub-batch. Used by the split-retry path to gate row + * splitting. + */ + lazy val areAllDeterministic = exprTiers.forall { tier => + tier.forall(_.deterministic) + } + lazy val retryables: Seq[Retryable] = exprTiers.flatMap(GpuExpressionsUtils.collectRetryables) lazy val outputExprs = exprTiers.last.toArray @@ -997,17 +1098,27 @@ case class GpuProjectAstExec( // If all of the expressions are retryable we can just run everything and retry it // at the top level. If some things are not retryable we need to split them up and // do the processing in a way that makes it so retries are more likely to succeed. - val sbToClose = if (closeInputBatch) { - Some(sb) + if (areAllDeterministic && new RapidsConf(SQLConf.get).isProjectSplitRetryEnabled) { + // Split-retry path: on GPU OOM, halve the input batch by rows and + // re-run the projection on each half. 
runWithSplitRetry takes + // ownership of the SpillableColumnarBatch and closes it; if the + // caller asked us not to close `sb`, increment the ref count to + // compensate. + val sbForRetry = if (closeInputBatch) sb else sb.incRefCount() + GpuProjectExec.runWithSplitRetry(sbForRetry, retryables, project(_)) } else { - None - } - withResource(sbToClose) { _ => - retryables.foreach(_.checkpoint()) - RmmRapidsRetryIterator.withRetryNoSplit { - withResource(sb.getColumnarBatch()) { cb => - withRestoreOnRetry(retryables) { - project(cb) + val sbToClose = if (closeInputBatch) { + Some(sb) + } else { + None + } + withResource(sbToClose) { _ => + retryables.foreach(_.checkpoint()) + RmmRapidsRetryIterator.withRetryNoSplit { + withResource(sb.getColumnarBatch()) { cb => + withRestoreOnRetry(retryables) { + project(cb) + } } } } From 68ca3fef754ba70f0cd19494c7cff4a82f80af79 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Wed, 29 Apr 2026 14:03:11 +0800 Subject: [PATCH 2/5] remove debug it Signed-off-by: Haoyang Li --- .../python/legacy_parser_oom_repro_test.py | 68 ------------------- 1 file changed, 68 deletions(-) delete mode 100644 integration_tests/src/main/python/legacy_parser_oom_repro_test.py diff --git a/integration_tests/src/main/python/legacy_parser_oom_repro_test.py b/integration_tests/src/main/python/legacy_parser_oom_repro_test.py deleted file mode 100644 index 627d22f0865..00000000000 --- a/integration_tests/src/main/python/legacy_parser_oom_repro_test.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright (c) 2026, NVIDIA CORPORATION. -# -# End-to-end reproducer for the customer GPU OOM in the legacy timestamp -# parser. The interesting GPU work is the projection of -# `from_unixtime(... + unix_timestamp(ts_str, fmt))` under -# `spark.sql.legacy.timeParserPolicy=LEGACY`, which routes to -# GpuToTimestamp.parseStringAsTimestampWithLegacyParserPolicy and into cuDF -# stringReplaceWithBackrefs / matches_re. 
cuDF allocates a per-row scratch -# buffer (~600 B/row in our measurement) that the spark-rapids estimator -# cannot see, so before split-retry the projection would OOM the pool. -# -# With Phase 1 (GpuProjectExec.projectWithSplitRetry, default-on), the OOM is -# caught at the project layer; the input batch is halved by rows and the -# projection is re-run on each half until each piece fits. -# -# To actually reproduce the OOM-pressure path, the test must be run with: -# - small RMM pool (PYSP_TEST_spark_rapids_memory_gpu_allocSize=8g) -# - reader emitting a single large batch (large parquet row group + -# spark.rapids.sql.reader.chunked=false + -# spark.rapids.sql.reader.batchSizeBytes=3g) -# - TEST_PARALLEL=1 so the pool isn't sliced across xdist workers -# Without those, the GPU reader chunks the input and no single batch is large -# enough to trigger the cuDF-scratch OOM. - -import pytest -import pyspark.sql.functions as f - -from spark_session import with_cpu_session, with_gpu_session - - -@pytest.mark.parametrize("rows", [100_000_000, 200_000_000], ids=["100M", "200M"]) -def test_legacy_parser_oom_repro(spark_tmp_path, rows): - data_path = spark_tmp_path + '/LEGACY_PARSER_OOM' - - # Force a single huge parquet row group so the GPU reader sees a single - # big batch when chunked-reader is disabled at run time. - write_conf = { - 'spark.sql.parquet.block.size': str(4 * 1024 * 1024 * 1024), - 'parquet.block.size': str(4 * 1024 * 1024 * 1024), - } - # The CASE keeps Catalyst from folding ts_str into table metadata; in - # practice every row holds the same valid timestamp string. 
- with_cpu_session(lambda spark: spark.range(0, rows, numPartitions=1) - .selectExpr( - "id as offset_long", - "case when id >= 0 then '2024-06-15 12:34:56' " - "else '2024-06-15 12:34:55' end as ts_str") - .write.mode('overwrite').parquet(data_path), conf=write_conf) - - def run(spark): - # Aggregate to a single scalar so the projected `t` column is fully - # materialized on the GPU but we don't ship a billion rows back. - spark.read.parquet(data_path).selectExpr( - "from_unixtime(offset_long + " - "unix_timestamp(ts_str, 'yyyy-MM-dd HH:mm:ss')) as t" - ).agg(f.sum(f.length('t'))).collect() - - conf = { - "spark.sql.legacy.timeParserPolicy": "LEGACY", - # Required to let UnixTimestamp/FromUnixTime with LEGACY format reach - # the GPU path (parseStringAsTimestampWithLegacyParserPolicy). This - # is the knob the customer has turned on in production. - "spark.rapids.sql.incompatibleDateFormats.enabled": "true", - # Phase 1 split-retry is default-on; left explicit here to make the - # test's expectation visible and to allow flipping it for debugging. 
- "spark.rapids.sql.projectExec.splitRetry.enabled": "true", - } - with_gpu_session(run, conf=conf) From b55ab404b685a089dff48906cba312cc8ed66532 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Wed, 29 Apr 2026 16:15:02 +0800 Subject: [PATCH 3/5] refine Signed-off-by: Haoyang Li --- .../python/gpu_project_split_retry_test.py | 69 ------- .../com/nvidia/spark/rapids/RapidsConf.scala | 1 + .../spark/rapids/ProjectSplitRetrySuite.scala | 189 ++++++++++++++++++ 3 files changed, 190 insertions(+), 69 deletions(-) delete mode 100644 integration_tests/src/main/python/gpu_project_split_retry_test.py create mode 100644 tests/src/test/scala/com/nvidia/spark/rapids/ProjectSplitRetrySuite.scala diff --git a/integration_tests/src/main/python/gpu_project_split_retry_test.py b/integration_tests/src/main/python/gpu_project_split_retry_test.py deleted file mode 100644 index 9e4f1eb7932..00000000000 --- a/integration_tests/src/main/python/gpu_project_split_retry_test.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright (c) 2026, NVIDIA CORPORATION. -# -# Tests for GpuProjectExec split-retry (Phase 1 of zero-config batching). -# -# `spark.rapids.sql.test.injectRetryOOM` lets us deterministically inject a -# GPU SplitAndRetryOOM mid-projection. With split-retry enabled (default), the -# new path in GpuProjectExec.projectWithSplitRetry must recover by halving the -# input batch by rows and re-running the projection on each half. - -import pyspark.sql.functions as f - -from spark_session import with_gpu_session - - -# Note: we deliberately do NOT have a "@inject_oom split=true" test here. -# `forceSplitAndRetryOOM` is re-fired at every retry-iterator instance the -# task creates, and GpuRangeExec uses its own withRetry(reduceRowsNumberByHalf) -# loop — under split=true that loop divides row count down to its lower -# bound and aborts before ever reaching GpuProjectExec. 
Real -# split-retry-on-OOM coverage for GpuProjectExec lives in the -# legacy_parser_oom_repro test (which uses real cuDF-scratch pressure -# rather than synthetic injection). - - -def test_project_split_retry_handles_plain_retry_oom(): - """Inject a plain GpuRetryOOM (not SplitAndRetry). The retry framework - should resolve this on its own without invoking the splitter; the new - path must not get in the way.""" - def run(spark): - return (spark.range(0, 10_000, numPartitions=1) - .selectExpr("id + 1 as a", "cast(id as string) as b") - .collect()) - result = with_gpu_session(run, conf={ - "spark.rapids.sql.test.injectRetryOOM": "num_ooms=1,type=GPU,split=false", - "spark.rapids.sql.projectExec.splitRetry.enabled": "true", - }) - assert len(result) == 10_000 - - -def test_project_split_retry_disabled_falls_back_to_legacy_path(): - """When split-retry is disabled by conf, the legacy withRetryNoSplit path - is used. Inject a non-split retry OOM so the legacy path can resolve it, - confirming the conf knob actually wires through.""" - def run(spark): - return (spark.range(0, 10_000, numPartitions=1) - .selectExpr("id + 1 as a") - .collect()) - result = with_gpu_session(run, conf={ - "spark.rapids.sql.test.injectRetryOOM": "num_ooms=1,type=GPU,split=false", - "spark.rapids.sql.projectExec.splitRetry.enabled": "false", - }) - assert len(result) == 10_000 - - -def test_project_with_nondeterministic_runs_normally(): - """Mixed deterministic + non-deterministic projection must take the - legacy withRetryNoSplit path even with split-retry enabled (because the - row-stitching logic requires alignment that row-splitting would break). 
- This sanity test confirms the dispatch logic does not break ordinary - execution; no OOM injection so the legacy path is exercised in its - happy path.""" - def run(spark): - return (spark.range(0, 10_000, numPartitions=1) - .selectExpr("id + 1 as a", "rand() as r") - .collect()) - result = with_gpu_session(run, conf={ - "spark.rapids.sql.projectExec.splitRetry.enabled": "true", - }) - assert len(result) == 10_000 diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala index 97f31cd2c5f..5ae8fa6614f 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -2745,6 +2745,7 @@ val SHUFFLE_COMPRESSION_LZ4_CHUNK_SIZE = conf("spark.rapids.shuffle.compression. "non-deterministic side is computed once on the full batch and stitched " + "row-by-row to the deterministic side, which row-splitting would break. " + "Disable this to revert to the prior behavior.") + .internal() .booleanConf .createWithDefault(true) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/ProjectSplitRetrySuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/ProjectSplitRetrySuite.scala new file mode 100644 index 00000000000..99ea7a5fdb0 --- /dev/null +++ b/tests/src/test/scala/com/nvidia/spark/rapids/ProjectSplitRetrySuite.scala @@ -0,0 +1,189 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import ai.rapids.cudf.ColumnVector +import com.nvidia.spark.rapids.Arm.withResource +import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableProducingSeq +import com.nvidia.spark.rapids.jni.{GpuSplitAndRetryOOM, RmmSpark} + +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExprId, NamedExpression} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.rapids.GpuAdd +import org.apache.spark.sql.rapids.catalyst.expressions.GpuRand +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.vectorized.ColumnarBatch + +class ProjectSplitRetrySuite extends RmmSparkRetrySuiteBase { + private val NUM_ROWS = 500 + private val RAND_SEED = 10 + private val intAttr = AttributeReference("int", IntegerType)(ExprId(10)) + private val batchAttrs = Seq(intAttr) + + private def buildBatch(): ColumnarBatch = { + val ints = 0 until NUM_ROWS + new ColumnarBatch( + Array(GpuColumnVector.from(ColumnVector.fromInts(ints: _*), IntegerType)), + ints.length) + } + + private def newSpillable(): SpillableColumnarBatch = + SpillableColumnarBatch(buildBatch(), SpillPriorities.ACTIVE_ON_DECK_PRIORITY) + + // GpuAdd(int, 1) — pure, deterministic, retryable. 
+ private def addOneExprs(): Seq[GpuExpression] = Seq( + GpuAlias(GpuAdd( + GpuBoundReference(0, IntegerType, true)(NamedExpression.newExprId, "int"), + GpuLiteral(1, IntegerType), + failOnError = false)(), + "plus_one")()) + + private def collectInts(cb: ColumnarBatch, col: Int): Array[Int] = { + val gcv = cb.column(col).asInstanceOf[GpuColumnVector] + withResource(gcv.copyToHost()) { hcv => + (0 until cb.numRows()).map(hcv.getInt).toArray + } + } + + private def collectDoubles(cb: ColumnarBatch, col: Int): Array[Double] = { + val gcv = cb.column(col).asInstanceOf[GpuColumnVector] + withResource(gcv.copyToHost()) { hcv => + (0 until cb.numRows()).map(hcv.getDouble).toArray + } + } + + // Helper: build the SpillableColumnarBatch BEFORE injecting the OOM, so + // that the alloc inside ColumnVector.fromInts doesn't accidentally absorb + // the injection. Only the projection itself should trip the OOM. + private def withInjectedOOM[T](inject: => Unit)(body: SpillableColumnarBatch => T): T = { + val sb = newSpillable() + inject + body(sb) + } + + test("split-retry produces same output as a single-batch projection") { + val out = withInjectedOOM { + RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId, 1, + RmmSpark.OomInjectionType.GPU.ordinal, 0) + } { sb => + GpuProjectExec.projectAndCloseWithRetrySingleBatch(sb, addOneExprs()) + } + withResource(out) { cb => + assertResult(NUM_ROWS)(cb.numRows()) + assertResult(1)(cb.numCols()) + val got = collectInts(cb, 0) + (0 until NUM_ROWS).foreach { i => + assertResult(i + 1)(got(i)) + } + } + assert(RmmSpark.getAndResetNumSplitRetryThrow(/*taskId*/ 1) > 0, + "expected at least one SplitAndRetryOOM to have been observed") + } + + test("conf=false routes split-retry OOM to legacy path which fails") { + val sqlConf = new SQLConf() + sqlConf.setConfString(RapidsConf.PROJECT_SPLIT_RETRY_ENABLED.key, "false") + SQLConf.withExistingConf(sqlConf) { + val sb = newSpillable() + 
RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId, 1, + RmmSpark.OomInjectionType.GPU.ordinal, 0) + assertThrows[GpuSplitAndRetryOOM] { + GpuProjectExec.projectAndCloseWithRetrySingleBatch(sb, addOneExprs()).close() + } + } + } + + test("tiered project split-retry produces correct output") { + val tier = GpuBindReferences.bindGpuReferencesTiered( + addOneExprs(), batchAttrs, new SQLConf(), Map.empty) + assert(tier.areAllRetryable && tier.areAllDeterministic) + val out = withInjectedOOM { + RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId, 1, + RmmSpark.OomInjectionType.GPU.ordinal, 0) + } { sb => + tier.projectAndCloseWithRetrySingleBatch(sb) + } + withResource(out) { cb => + assertResult(NUM_ROWS)(cb.numRows()) + val got = collectInts(cb, 0) + (0 until NUM_ROWS).foreach { i => + assertResult(i + 1)(got(i)) + } + } + assert(RmmSpark.getAndResetNumSplitRetryThrow(/*taskId*/ 1) > 0) + } + + // A non-deterministic projection (containing GpuRand) must NOT take the + // split path even when the conf is on, because row-splitting would + // change rand state across the halves and break the row-aligned stitch + // between deterministic and rand columns. forceRetryOOM (plain, not + // split) verifies the legacy withRetryNoSplit path is selected and + // checkpoint/restore reproduces the rand sequence on retry. 
+ test("mixed deterministic + GpuRand falls back to legacy retry path") { + def projection(): Seq[GpuExpression] = Seq( + GpuAlias(GpuAdd( + GpuBoundReference(0, IntegerType, true)(NamedExpression.newExprId, "int"), + GpuLiteral(1, IntegerType), + failOnError = false)(), "plus_one")(), + GpuAlias(GpuRand(GpuLiteral(RAND_SEED, IntegerType), false), "rnd")()) + + val batches = Seq(true, false).safeMap { forceRetry => + val tier = GpuBindReferences.bindGpuReferencesTiered( + projection(), batchAttrs, new SQLConf(), Map.empty) + assert(tier.areAllRetryable && !tier.areAllDeterministic) + val sb = newSpillable() + if (forceRetry) { + RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 1, + RmmSpark.OomInjectionType.GPU.ordinal, 0) + } + tier.projectAndCloseWithRetrySingleBatch(sb) + } + withResource(batches) { case Seq(retried, ref) => + assertResult(ref.numRows())(retried.numRows()) + assertResult(ref.numCols())(retried.numCols()) + val refPlus = collectInts(ref, 0) + val retPlus = collectInts(retried, 0) + val refRand = collectDoubles(ref, 1) + val retRand = collectDoubles(retried, 1) + (0 until NUM_ROWS).foreach { i => + assertResult(refPlus(i))(retPlus(i)) + assertResult(refRand(i))(retRand(i)) + } + } + } + + // A plain GpuRetryOOM under the new path is resolved before the + // splitter is invoked, so the result comes back as a single piece — + // exercising the `pieces.length == 1` early-return in runWithSplitRetry. 
+ test("plain GpuRetryOOM under split-retry path returns a single piece") { + val out = withInjectedOOM { + RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 1, + RmmSpark.OomInjectionType.GPU.ordinal, 0) + } { sb => + GpuProjectExec.projectAndCloseWithRetrySingleBatch(sb, addOneExprs()) + } + withResource(out) { cb => + assertResult(NUM_ROWS)(cb.numRows()) + val got = collectInts(cb, 0) + (0 until NUM_ROWS).foreach { i => + assertResult(i + 1)(got(i)) + } + } + assertResult(0)(RmmSpark.getAndResetNumSplitRetryThrow(/*taskId*/ 1)) + assert(RmmSpark.getAndResetNumRetryThrow(/*taskId*/ 1) > 0) + } +} From a5f6dbba9cb8916d5ef6d933bd080a77bf835f53 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 5 May 2026 11:46:59 +0800 Subject: [PATCH 4/5] simplify Signed-off-by: Haoyang Li --- .../spark/rapids/basicPhysicalOperators.scala | 42 +++---------------- .../spark/rapids/ProjectSplitRetrySuite.scala | 8 ++++ 2 files changed, 13 insertions(+), 37 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala index cc37911d3a2..f33a3569697 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala @@ -232,9 +232,9 @@ object GpuProjectExec { /** * Run a deterministic projection with row-split retry. On GPU OOM the retry * framework calls splitSpillableInHalfByRows to halve the input batch and - * re-runs the projection on each half. The resulting sub-batches are - * concatenated back into a single output batch to preserve the single-batch - * contract of projectAndCloseWithRetrySingleBatch. + * re-runs the projection on each half; sub-batches are concatenated back + * into a single output batch to preserve the single-batch contract of + * projectAndCloseWithRetrySingleBatch. 
* * Caller must ensure the projection driven by `runProject` is purely * deterministic — non-deterministic expressions cannot be safely @@ -259,47 +259,15 @@ object GpuProjectExec { } } } - // Drain the retry iterator. Each piece is an independently-projected - // sub-batch. If draining itself throws (e.g. a later split also OOMs and - // retry is exhausted), close any pieces collected so far before - // propagating. val pieces = ArrayBuffer[ColumnarBatch]() closeOnExcept(pieces) { _ => while (resultIter.hasNext) { pieces += resultIter.next() } - } - if (pieces.length == 1) { - pieces.head - } else { - concatColumnarBatches(pieces.toArray) - } - } - - /** - * Concatenate a non-empty array of ColumnarBatches into a single ColumnarBatch. - * Closes all input batches on success; on failure, the input batches are also - * closed via closeOnExcept. The returned batch's device buffers are - * independent of the inputs (Table.concatenate copies). - * - * Note: if the concatenation itself OOMs, it will be caught by whatever outer - * retry layer surrounds the caller. We don't try to recover here because by - * definition the retry framework already split the input as far as it could. - */ - private def concatColumnarBatches(pieces: Array[ColumnarBatch]): ColumnarBatch = { - require(pieces.nonEmpty, "concatColumnarBatches requires at least one piece") - closeOnExcept(pieces) { _ => - val outputTypes: Array[DataType] = (0 until pieces.head.numCols()).map { i => + val outputTypes = (0 until pieces.head.numCols()).map { i => pieces.head.column(i).asInstanceOf[GpuColumnVector].dataType() }.toArray - val result = withResource(pieces.safeMap(GpuColumnVector.from)) { tables => - withResource(Table.concatenate(tables: _*)) { concatenated => - GpuColumnVector.from(concatenated, outputTypes) - } - } - // Result holds independent device buffers; release the input pieces. 
-      pieces.foreach(_.close())
-      result
+      ConcatAndConsumeAll.buildNonEmptyBatchFromTypes(pieces.toArray, outputTypes)
     }
   }
 }
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/ProjectSplitRetrySuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/ProjectSplitRetrySuite.scala
index 99ea7a5fdb0..cc5317880d2 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/ProjectSplitRetrySuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/ProjectSplitRetrySuite.scala
@@ -34,6 +34,14 @@ class ProjectSplitRetrySuite extends RmmSparkRetrySuiteBase {
   private val intAttr = AttributeReference("int", IntegerType)(ExprId(10))
   private val batchAttrs = Seq(intAttr)
 
+  // Reset retry counters so a leaked count from one test cannot mask a
+  // missed injection in the next.
+  override def afterEach(): Unit = {
+    RmmSpark.getAndResetNumRetryThrow(/*taskId*/ 1)
+    RmmSpark.getAndResetNumSplitRetryThrow(/*taskId*/ 1)
+    super.afterEach()
+  }
+
   private def buildBatch(): ColumnarBatch = {
     val ints = 0 until NUM_ROWS
     new ColumnarBatch(
From 6eb27e93aeba17eebcf320cbf1060de93be15dce Mon Sep 17 00:00:00 2001
From: Haoyang Li
Date: Tue, 5 May 2026 15:10:49 +0800
Subject: [PATCH 5/5] address comments

Signed-off-by: Haoyang Li
---
 .../com/nvidia/spark/rapids/basicPhysicalOperators.scala | 5 ++++-
 .../com/nvidia/spark/rapids/ProjectSplitRetrySuite.scala | 6 +++---
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala
index f33a3569697..744e6a01adf 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala
@@ -175,7 +175,10 @@ object GpuProjectExec {
     if (new RapidsConf(SQLConf.get).isProjectSplitRetryEnabled &&
         boundExprs.forall(_.deterministic)) {
       val retryables =
GpuExpressionsUtils.collectRetryables(boundExprs) - return runWithSplitRetry(sb, retryables, project(_, boundExprs)) + // runWithSplitRetry takes ownership of the SpillableColumnarBatch; bump + // the ref count so the caller (which is responsible for closing `sb`, + // per this method's contract) doesn't double-close it. + return runWithSplitRetry(sb.incRefCount(), retryables, project(_, boundExprs)) } // First off we want to find/run all of the expressions that are not retryable, diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/ProjectSplitRetrySuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/ProjectSplitRetrySuite.scala index cc5317880d2..ee07fe52e55 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/ProjectSplitRetrySuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/ProjectSplitRetrySuite.scala @@ -174,9 +174,9 @@ class ProjectSplitRetrySuite extends RmmSparkRetrySuiteBase { } } - // A plain GpuRetryOOM under the new path is resolved before the - // splitter is invoked, so the result comes back as a single piece — - // exercising the `pieces.length == 1` early-return in runWithSplitRetry. + // A plain GpuRetryOOM under the new path is resolved before the splitter + // is invoked, so the result comes back as a single piece — exercising the + // single-piece path through ConcatAndConsumeAll.buildNonEmptyBatchFromTypes. test("plain GpuRetryOOM under split-retry path returns a single piece") { val out = withInjectedOOM { RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 1,