Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2737,6 +2737,18 @@ val SHUFFLE_COMPRESSION_LZ4_CHUNK_SIZE = conf("spark.rapids.shuffle.compression.
.longConf
.createOptional

// Internal escape hatch for the GpuProjectExec split-and-retry OOM recovery
// path; defaults to enabled. Read through RapidsConf.isProjectSplitRetryEnabled.
val PROJECT_SPLIT_RETRY_ENABLED = conf("spark.rapids.sql.projectExec.splitRetry.enabled")
  .doc("When true, GpuProjectExec uses split-and-retry on GPU OOM for purely " +
    "deterministic projections: the input batch is halved by rows and the " +
    "projection is re-run on each half. Mixed deterministic + non-deterministic " +
    "projections fall back to the existing withRetryNoSplit path because the " +
    "non-deterministic side is computed once on the full batch and stitched " +
    "row-by-row to the deterministic side, which row-splitting would break. " +
    "Disable this to revert to the prior behavior.")
  .internal()
  .booleanConf
  .createWithDefault(true)

val TEST_IO_ENCRYPTION = conf("spark.rapids.test.io.encryption")
.doc("Only for tests: verify for IO encryption")
.internal()
Expand Down Expand Up @@ -3934,6 +3946,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging {

lazy val chunkedPackBounceBufferSize: Long = get(CHUNKED_PACK_BOUNCE_BUFFER_SIZE)

// Per-query view of PROJECT_SPLIT_RETRY_ENABLED; gates the split-and-retry
// projection path in GpuProjectExec.
lazy val isProjectSplitRetryEnabled: Boolean = get(PROJECT_SPLIT_RETRY_ENABLED)

lazy val chunkedPackBounceBufferCount: Int = get(CHUNKED_PACK_BOUNCE_BUFFER_COUNT)

lazy val spillToDiskBounceBufferSize: Long = get(SPILL_TO_DISK_BOUNCE_BUFFER_SIZE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection, RangePartitioning, SinglePartition, UnknownPartitioning}
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SampleExec, SparkPlan}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.{GpuCreateArray, GpuCreateMap, GpuCreateNamedStruct, GpuPartitionwiseSampledRDD, GpuPoissonSampler}
import org.apache.spark.sql.rapids.execution.TrampolineUtil
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -162,6 +163,24 @@ object GpuProjectExec {
def projectWithRetrySingleBatch(sb: SpillableColumnarBatch,
boundExprs: Seq[Expression]): ColumnarBatch = {

// For purely deterministic projections, use split-retry: on GPU OOM, halve
// the input batch by rows and re-run on each half. This recovers from
// cuDF-internal scratch allocations that the pre-split estimator cannot
// see (e.g. regex / string-replace working memory).
//
// Mixed deterministic + non-deterministic projections fall through to the
// existing withRetryNoSplit path: the non-deterministic side is computed
// once on the full input batch and stitched row-by-row to the deterministic
// side, and row-splitting either side would break that alignment.
if (new RapidsConf(SQLConf.get).isProjectSplitRetryEnabled &&
boundExprs.forall(_.deterministic)) {
val retryables = GpuExpressionsUtils.collectRetryables(boundExprs)
// runWithSplitRetry takes ownership of the SpillableColumnarBatch; bump
// the ref count so the caller (which is responsible for closing `sb`,
// per this method's contract) doesn't double-close it.
return runWithSplitRetry(sb.incRefCount(), retryables, project(_, boundExprs))
}

// First off we want to find/run all of the expressions that are not retryable,
// These cannot be retried.
val (retryableExprs, notRetryableExprs) = boundExprs.partition(
Expand Down Expand Up @@ -212,6 +231,48 @@ object GpuProjectExec {
}
}
}

/**
 * Run a deterministic projection with row-split retry. On GPU OOM the retry
 * framework calls splitSpillableInHalfByRows to halve the input batch and
 * re-runs the projection on each half; sub-batches are concatenated back
 * into a single output batch to preserve the single-batch contract of
 * projectAndCloseWithRetrySingleBatch.
 *
 * Caller must ensure the projection driven by `runProject` is purely
 * deterministic — non-deterministic expressions cannot be safely
 * re-evaluated on row-split sub-batches.
 *
 * `runProject` receives a (non-spillable) ColumnarBatch and returns the
 * projected ColumnarBatch. It must not close its input (the framework will).
 *
 * Takes ownership of `sb`: it is closed by the retry iterator when drained.
 * If the caller does not want to surrender ownership, it must increment the
 * ref count before calling.
 */
private[rapids] def runWithSplitRetry(
    sb: SpillableColumnarBatch,
    retryables: Seq[Retryable],
    runProject: ColumnarBatch => ColumnarBatch): ColumnarBatch = {
  retryables.foreach(_.checkpoint())
  val resultIter = withRetry(sb, splitSpillableInHalfByRows) { spillable =>
    withResource(spillable.getColumnarBatch()) { cb =>
      withRestoreOnRetry(retryables) {
        runProject(cb)
      }
    }
  }
  // Drain the retry iterator under closeOnExcept so already-collected
  // sub-batches are released if a later next() throws. The scope covers
  // ONLY the collection loop — see below.
  val pieces = ArrayBuffer[ColumnarBatch]()
  closeOnExcept(pieces) { _ =>
    while (resultIter.hasNext) {
      pieces += resultIter.next()
    }
  }
  val outputTypes = (0 until pieces.head.numCols()).map { i =>
    pieces.head.column(i).asInstanceOf[GpuColumnVector].dataType()
  }.toArray
  // Transfer ownership: buildNonEmptyBatchFromTypes closes the input batches
  // itself (in its finally block). It must NOT also be wrapped in
  // closeOnExcept(pieces): if the concatenation throws (e.g. OOM inside
  // Table.concatenate), both handlers would close the same GpuColumnVectors,
  // driving the cuDF native ref-count negative.
  ConcatAndConsumeAll.buildNonEmptyBatchFromTypes(pieces.toArray, outputTypes)
}
Comment on lines +265 to +275
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Double-close of pieces when buildNonEmptyBatchFromTypes throws in the multi-batch path. ConcatAndConsumeAll.buildNonEmptyBatchFromTypes closes all input batches in its finally block (arrayOfBatches.foreach(_.close())). If it throws (e.g., OOM during Table.concatenate), both that finally and the closeOnExcept(pieces) handler call close on the same GpuColumnVector objects, driving the cuDF native ref-count negative. Narrow the closeOnExcept scope to cover only the collection loop and let buildNonEmptyBatchFromTypes take exclusive ownership for the concat step.

Suggested change
val pieces = ArrayBuffer[ColumnarBatch]()
closeOnExcept(pieces) { _ =>
while (resultIter.hasNext) {
pieces += resultIter.next()
}
val outputTypes = (0 until pieces.head.numCols()).map { i =>
pieces.head.column(i).asInstanceOf[GpuColumnVector].dataType()
}.toArray
ConcatAndConsumeAll.buildNonEmptyBatchFromTypes(pieces.toArray, outputTypes)
}
}
val pieces = ArrayBuffer[ColumnarBatch]()
closeOnExcept(pieces) { _ =>
while (resultIter.hasNext) {
pieces += resultIter.next()
}
}
val outputTypes = (0 until pieces.head.numCols()).map { i =>
pieces.head.column(i).asInstanceOf[GpuColumnVector].dataType()
}.toArray
// Transfer ownership: buildNonEmptyBatchFromTypes closes the batches itself
// (in its finally block). Do not wrap in closeOnExcept here to avoid a
// double-close if Table.concatenate throws.
ConcatAndConsumeAll.buildNonEmptyBatchFromTypes(pieces.toArray, outputTypes)
}

}

/**
Expand Down Expand Up @@ -947,6 +1008,17 @@ case class GpuProjectAstExec(
}
}

/**
 * Are all expressions across all tiers deterministic. This is a stricter
 * check than [[areAllRetryable]] — a Retryable but non-deterministic
 * expression (e.g. GpuRand) is retryable but cannot be safely re-evaluated
 * on a row-split sub-batch. Used by the split-retry path to gate row
 * splitting.
 */
// Explicit result type added for consistency with the sibling public
// members (e.g. `retryables: Seq[Retryable]`) and API clarity.
lazy val areAllDeterministic: Boolean = exprTiers.forall { tier =>
  tier.forall(_.deterministic)
}

// Every Retryable collected across all tiers; checkpointed before and
// restored after the retry loops below.
lazy val retryables: Seq[Retryable] = exprTiers.flatMap(GpuExpressionsUtils.collectRetryables)

// The last tier's expressions define the projection's output columns.
lazy val outputExprs = exprTiers.last.toArray
Expand Down Expand Up @@ -997,17 +1069,27 @@ case class GpuProjectAstExec(
// If all of the expressions are retryable we can just run everything and retry it
// at the top level. If some things are not retryable we need to split them up and
// do the processing in a way that makes it so retries are more likely to succeed.
val sbToClose = if (closeInputBatch) {
Some(sb)
if (areAllDeterministic && new RapidsConf(SQLConf.get).isProjectSplitRetryEnabled) {
// Split-retry path: on GPU OOM, halve the input batch by rows and
// re-run the projection on each half. runWithSplitRetry takes
// ownership of the SpillableColumnarBatch and closes it; if the
// caller asked us not to close `sb`, increment the ref count to
// compensate.
val sbForRetry = if (closeInputBatch) sb else sb.incRefCount()
GpuProjectExec.runWithSplitRetry(sbForRetry, retryables, project(_))
} else {
None
}
withResource(sbToClose) { _ =>
retryables.foreach(_.checkpoint())
RmmRapidsRetryIterator.withRetryNoSplit {
withResource(sb.getColumnarBatch()) { cb =>
withRestoreOnRetry(retryables) {
project(cb)
val sbToClose = if (closeInputBatch) {
Some(sb)
} else {
None
}
withResource(sbToClose) { _ =>
retryables.foreach(_.checkpoint())
RmmRapidsRetryIterator.withRetryNoSplit {
withResource(sb.getColumnarBatch()) { cb =>
withRestoreOnRetry(retryables) {
project(cb)
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
/*
* Copyright (c) 2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids

import ai.rapids.cudf.ColumnVector
import com.nvidia.spark.rapids.Arm.withResource
import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableProducingSeq
import com.nvidia.spark.rapids.jni.{GpuSplitAndRetryOOM, RmmSpark}

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExprId, NamedExpression}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.GpuAdd
import org.apache.spark.sql.rapids.catalyst.expressions.GpuRand
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.vectorized.ColumnarBatch

class ProjectSplitRetrySuite extends RmmSparkRetrySuiteBase {
  // Fixed batch size; all expected outputs below are derived from it.
  private val NUM_ROWS = 500
  // Seed for GpuRand so the mixed deterministic/rand test is reproducible.
  private val RAND_SEED = 10
  private val intAttr = AttributeReference("int", IntegerType)(ExprId(10))
  private val batchAttrs = Seq(intAttr)

  // Reset retry counters so a leaked count from one test cannot mask a
  // missed injection in the next.
  override def afterEach(): Unit = {
    RmmSpark.getAndResetNumRetryThrow(/*taskId*/ 1)
    RmmSpark.getAndResetNumSplitRetryThrow(/*taskId*/ 1)
    super.afterEach()
  }

  // Builds a single int-column batch holding the values 0 until NUM_ROWS.
  private def buildBatch(): ColumnarBatch = {
    val ints = 0 until NUM_ROWS
    new ColumnarBatch(
      Array(GpuColumnVector.from(ColumnVector.fromInts(ints: _*), IntegerType)),
      ints.length)
  }

  private def newSpillable(): SpillableColumnarBatch =
    SpillableColumnarBatch(buildBatch(), SpillPriorities.ACTIVE_ON_DECK_PRIORITY)

  // GpuAdd(int, 1) — pure, deterministic, retryable.
  private def addOneExprs(): Seq[GpuExpression] = Seq(
    GpuAlias(GpuAdd(
      GpuBoundReference(0, IntegerType, true)(NamedExpression.newExprId, "int"),
      GpuLiteral(1, IntegerType),
      failOnError = false)(),
      "plus_one")())

  // Copies int column `col` of `cb` to the host for assertions.
  private def collectInts(cb: ColumnarBatch, col: Int): Array[Int] = {
    val gcv = cb.column(col).asInstanceOf[GpuColumnVector]
    withResource(gcv.copyToHost()) { hcv =>
      (0 until cb.numRows()).map(hcv.getInt).toArray
    }
  }

  // Copies double column `col` of `cb` to the host for assertions.
  private def collectDoubles(cb: ColumnarBatch, col: Int): Array[Double] = {
    val gcv = cb.column(col).asInstanceOf[GpuColumnVector]
    withResource(gcv.copyToHost()) { hcv =>
      (0 until cb.numRows()).map(hcv.getDouble).toArray
    }
  }

  // Helper: build the SpillableColumnarBatch BEFORE injecting the OOM, so
  // that the alloc inside ColumnVector.fromInts doesn't accidentally absorb
  // the injection. Only the projection itself should trip the OOM.
  private def withInjectedOOM[T](inject: => Unit)(body: SpillableColumnarBatch => T): T = {
    val sb = newSpillable()
    inject
    body(sb)
  }

  // One forced split-and-retry OOM: the split path must halve, re-run, and
  // concatenate back to the same values a clean single-batch run would give.
  test("split-retry produces same output as a single-batch projection") {
    val out = withInjectedOOM {
      RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId, 1,
        RmmSpark.OomInjectionType.GPU.ordinal, 0)
    } { sb =>
      GpuProjectExec.projectAndCloseWithRetrySingleBatch(sb, addOneExprs())
    }
    withResource(out) { cb =>
      assertResult(NUM_ROWS)(cb.numRows())
      assertResult(1)(cb.numCols())
      val got = collectInts(cb, 0)
      (0 until NUM_ROWS).foreach { i =>
        assertResult(i + 1)(got(i))
      }
    }
    assert(RmmSpark.getAndResetNumSplitRetryThrow(/*taskId*/ 1) > 0,
      "expected at least one SplitAndRetryOOM to have been observed")
  }

  // With the conf off, the injected split OOM reaches the legacy no-split
  // path, which cannot recover from it and must surface GpuSplitAndRetryOOM.
  test("conf=false routes split-retry OOM to legacy path which fails") {
    val sqlConf = new SQLConf()
    sqlConf.setConfString(RapidsConf.PROJECT_SPLIT_RETRY_ENABLED.key, "false")
    SQLConf.withExistingConf(sqlConf) {
      val sb = newSpillable()
      RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId, 1,
        RmmSpark.OomInjectionType.GPU.ordinal, 0)
      assertThrows[GpuSplitAndRetryOOM] {
        GpuProjectExec.projectAndCloseWithRetrySingleBatch(sb, addOneExprs()).close()
      }
    }
  }

  // Same check as above but through the tiered-projection entry point.
  test("tiered project split-retry produces correct output") {
    val tier = GpuBindReferences.bindGpuReferencesTiered(
      addOneExprs(), batchAttrs, new SQLConf(), Map.empty)
    assert(tier.areAllRetryable && tier.areAllDeterministic)
    val out = withInjectedOOM {
      RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId, 1,
        RmmSpark.OomInjectionType.GPU.ordinal, 0)
    } { sb =>
      tier.projectAndCloseWithRetrySingleBatch(sb)
    }
    withResource(out) { cb =>
      assertResult(NUM_ROWS)(cb.numRows())
      val got = collectInts(cb, 0)
      (0 until NUM_ROWS).foreach { i =>
        assertResult(i + 1)(got(i))
      }
    }
    assert(RmmSpark.getAndResetNumSplitRetryThrow(/*taskId*/ 1) > 0)
  }

  // A non-deterministic projection (containing GpuRand) must NOT take the
  // split path even when the conf is on, because row-splitting would
  // change rand state across the halves and break the row-aligned stitch
  // between deterministic and rand columns. forceRetryOOM (plain, not
  // split) verifies the legacy withRetryNoSplit path is selected and
  // checkpoint/restore reproduces the rand sequence on retry.
  test("mixed deterministic + GpuRand falls back to legacy retry path") {
    def projection(): Seq[GpuExpression] = Seq(
      GpuAlias(GpuAdd(
        GpuBoundReference(0, IntegerType, true)(NamedExpression.newExprId, "int"),
        GpuLiteral(1, IntegerType),
        failOnError = false)(), "plus_one")(),
      GpuAlias(GpuRand(GpuLiteral(RAND_SEED, IntegerType), false), "rnd")())

    // Run twice: once with an injected (plain) retry OOM, once clean as the
    // reference; both batches must match row-for-row.
    val batches = Seq(true, false).safeMap { forceRetry =>
      val tier = GpuBindReferences.bindGpuReferencesTiered(
        projection(), batchAttrs, new SQLConf(), Map.empty)
      assert(tier.areAllRetryable && !tier.areAllDeterministic)
      val sb = newSpillable()
      if (forceRetry) {
        RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 1,
          RmmSpark.OomInjectionType.GPU.ordinal, 0)
      }
      tier.projectAndCloseWithRetrySingleBatch(sb)
    }
    withResource(batches) { case Seq(retried, ref) =>
      assertResult(ref.numRows())(retried.numRows())
      assertResult(ref.numCols())(retried.numCols())
      val refPlus = collectInts(ref, 0)
      val retPlus = collectInts(retried, 0)
      val refRand = collectDoubles(ref, 1)
      val retRand = collectDoubles(retried, 1)
      (0 until NUM_ROWS).foreach { i =>
        assertResult(refPlus(i))(retPlus(i))
        assertResult(refRand(i))(retRand(i))
      }
    }
  }

  // A plain GpuRetryOOM under the new path is resolved before the splitter
  // is invoked, so the result comes back as a single piece — exercising the
  // single-piece path through ConcatAndConsumeAll.buildNonEmptyBatchFromTypes.
  test("plain GpuRetryOOM under split-retry path returns a single piece") {
    val out = withInjectedOOM {
      RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 1,
        RmmSpark.OomInjectionType.GPU.ordinal, 0)
    } { sb =>
      GpuProjectExec.projectAndCloseWithRetrySingleBatch(sb, addOneExprs())
    }
    withResource(out) { cb =>
      assertResult(NUM_ROWS)(cb.numRows())
      val got = collectInts(cb, 0)
      (0 until NUM_ROWS).foreach { i =>
        assertResult(i + 1)(got(i))
      }
    }
    // No split should have been observed, but at least one plain retry must.
    assertResult(0)(RmmSpark.getAndResetNumSplitRetryThrow(/*taskId*/ 1))
    assert(RmmSpark.getAndResetNumRetryThrow(/*taskId*/ 1) > 0)
  }
}