From 1d34ced9884e0a94977d82f6c83c3e07c0cb03a8 Mon Sep 17 00:00:00 2001 From: micheal-o Date: Thu, 13 Nov 2025 20:51:44 -0800 Subject: [PATCH 1/9] runner --- .../resources/error/error-conditions.json | 79 +++++ .../StreamingCheckpointManager.scala | 50 +++ .../apache/spark/sql/classic/SQLContext.scala | 6 + .../spark/sql/classic/SparkSession.scala | 2 + .../classic/StreamingCheckpointManager.scala | 55 ++++ .../v2/state/StateDataSource.scala | 12 +- .../execution/streaming/StreamingUtils.scala | 28 ++ .../OfflineStateRepartitionBatchRunner.scala | 255 +++++++++++++++ .../state/OfflineStateRepartitionErrors.scala | 203 ++++++++++++ .../internal/BaseSessionStateBuilder.scala | 9 +- .../spark/sql/internal/SessionState.scala | 6 +- .../state/OfflineStateRepartitionSuite.scala | 301 ++++++++++++++++++ 12 files changed, 995 insertions(+), 11 deletions(-) create mode 100644 sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingCheckpointManager.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingCheckpointManager.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingUtils.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionBatchRunner.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index f7eb1e63d7bd..7757a66d339e 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -5229,6 +5229,85 @@ ], "sqlState" : "0A000" }, + "STATE_REPARTITION_INVALID_CHECKPOINT" : { + "message" : [ + "The provided checkpoint location '' is in an invalid state." + ], + "subClass" : { + "LAST_BATCH_ABANDONED_REPARTITION" : { + "message" : [ + "The last batch ID is a repartition batch with shuffle partitions and didn't finish successfully.", + "You're now requesting to repartition to shuffle partitions.", + "Please retry with the same number of shuffle partitions as the previous attempt.", + "Once that completes successfully, you can repartition to another number of shuffle partitions." + ] + }, + "LAST_BATCH_FAILED" : { + "message" : [ + "The last batch ID didn't finish successfully. Please make sure the streaming query finishes successfully, before repartitioning.", + "If using ProcessingTime trigger, you can use AvailableNow trigger instead, which will make sure the query terminates successfully by itself.", + "If you want to skip this check, set enforceExactlyOnceSink parameter in repartition to false.", + "But this can cause duplicate output records from the failed batch when using exactly-once sinks." + ] + }, + "MISSING_OFFSET_SEQ_METADATA" : { + "message" : [ + "The OffsetSeq (v) metadata is missing for batch ID . Please make sure the checkpoint is from a supported Spark version." + ] + }, + "NO_BATCH_FOUND" : { + "message" : [ + "No microbatch has been recorded in the checkpoint location. Make sure the streaming query has successfully completed at least one microbatch before repartitioning." + ] + }, + "NO_COMMITTED_BATCH" : { + "message" : [ + "There is no committed microbatch. 
Make sure the streaming query has successfully completed at least one microbatch before repartitioning." + ] + }, + "OFFSET_SEQ_NOT_FOUND" : { + "message" : [ + "Offset sequence entry for batch ID not found. You might have set a very low value for", + "'spark.sql.streaming.minBatchesToRetain' config during the streaming query execution or you deleted files in the checkpoint location." + ] + }, + "SHUFFLE_PARTITIONS_ALREADY_MATCH" : { + "message" : [ + "The number of shuffle partitions in the last committed batch (id=) is the same as the requested partitions.", + "Hence, already has the requested number of partitions, so no-op." + ] + }, + "UNSUPPORTED_OFFSET_SEQ_VERSION" : { + "message" : [ + "Unsupported offset sequence version . Please make sure the checkpoint is from a supported Spark version." + ] + } + }, + "sqlState" : "55019" + }, + "STATE_REPARTITION_INVALID_PARAMETER" : { + "message" : [ + "The repartition parameter is invalid:" + ], + "subClass" : { + "IS_EMPTY" : { + "message" : [ + "cannot be empty." + ] + }, + "IS_NOT_GREATER_THAN_ZERO" : { + "message" : [ + "must be greater than zero." + ] + }, + "IS_NULL" : { + "message" : [ + "cannot be null." + ] + } + }, + "sqlState" : "42616" + }, "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE" : { "message" : [ "Failed to perform stateful processor operation= with invalid handle state=." diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingCheckpointManager.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingCheckpointManager.scala new file mode 100644 index 000000000000..dfa6437b59e1 --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingCheckpointManager.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +/** + * A class to manage operations on streaming query checkpoints. + */ +private[spark] abstract class StreamingCheckpointManager { + + /** + * Repartition the stateful streaming operators state in the streaming checkpoint to have + * `numPartitions` partitions. The streaming query MUST not be running. + * If `numPartitions` is the same as the current number of partitions, this is a no-op, + * and an exception will be thrown. + * + * This produces a new microbatch in the checkpoint that contains the repartitioned state i.e. + * if the last streaming batch was batch `N`, this will create batch `N+1` with the repartitioned + * state. Note that this new batch doesn't read input data from sources, it only represents the + * repartition operation. The next time the streaming query is started, it will pick up from + * this new batch. + * + * This will return only when the repartitioning is complete or fails. 
+ * + * @note This operation should only be performed after the streaming query has been stopped. + * @param checkpointLocation The checkpoint location of the streaming query, should be the + * `checkpointLocation` option on the DataStreamWriter. + * @param numPartitions the target number of state partitions. + * @param enforceExactlyOnceSink if we shouldn't allow skipping failed batches, + * to avoid duplicates in exactly once sinks. + */ + private[spark] def repartition( + checkpointLocation: String, + numPartitions: Int, + enforceExactlyOnceSink: Boolean = true): Unit +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/SQLContext.scala index 18a84d8c4299..711a64b4589a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/SQLContext.scala @@ -145,6 +145,12 @@ class SQLContext private[sql] (override val sparkSession: SparkSession) */ def streams: StreamingQueryManager = sparkSession.streams + /** + * Returns a `StreamingCheckpointManager` that allows managing any streaming checkpoint. + */ + private[spark] def streamingCheckpointManager: StreamingCheckpointManager = + sparkSession.streamingCheckpointManager + /** @inheritdoc */ override def sparkContext: SparkContext = super.sparkContext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala index f7876d9a023b..1ac9941c07ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala @@ -223,6 +223,8 @@ class SparkSession private( @Unstable def streams: StreamingQueryManager = sessionState.streamingQueryManager + private[spark] def streamingCheckpointManager = sessionState.streamingCheckpointManager + /** * Returns an `ArtifactManager` that supports adding, managing and using session-scoped artifacts * (jars, classfiles, etc). diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingCheckpointManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingCheckpointManager.scala new file mode 100644 index 000000000000..ebe0f5b572dc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingCheckpointManager.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.classic + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.streaming.state.{OfflineStateRepartitionBatchRunner, OfflineStateRepartitionErrors} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming + +/** @inheritdoc */ +private[spark] class StreamingCheckpointManager( + sparkSession: SparkSession, + sqlConf: SQLConf) extends streaming.StreamingCheckpointManager with Logging { + + /** @inheritdoc */ + override private[spark] def repartition( + checkpointLocation: String, + numPartitions: Int, + enforceExactlyOnceSink: Boolean = true): Unit = { + checkpointLocation match { + case null => + throw OfflineStateRepartitionErrors.parameterIsNullError("checkpointLocation") + case "" => + throw OfflineStateRepartitionErrors.parameterIsEmptyError("checkpointLocation") + case _ => // Valid case, no action needed + } + + if (numPartitions <= 0) { + throw OfflineStateRepartitionErrors.parameterIsNotGreaterThanZeroError("numPartitions") + } + + val runner = new OfflineStateRepartitionBatchRunner( + sparkSession, + checkpointLocation, + numPartitions, + enforceExactlyOnceSink + ) + runner.run() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index fea4d345b8d0..af478aa5c97c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.{J import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues.JoinSideValues import org.apache.spark.sql.execution.datasources.v2.state.metadata.{StateMetadataPartitionReader, StateMetadataTableEntry} import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil +import org.apache.spark.sql.execution.streaming.StreamingUtils import org.apache.spark.sql.execution.streaming.checkpointing.OffsetSeqMetadata import org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorsUtils import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide} @@ -481,7 +482,8 @@ object StateSourceOptions extends DataSourceOptions { throw StateDataSourceErrors.conflictOptions(Seq(JOIN_SIDE, STORE_NAME)) } - val resolvedCpLocation = resolvedCheckpointLocation(hadoopConf, checkpointLocation) + val resolvedCpLocation = StreamingUtils.resolvedCheckpointLocation( + hadoopConf, checkpointLocation) var batchId = Option(options.get(BATCH_ID)).map(_.toLong) @@ -617,14 +619,6 @@ object StateSourceOptions extends DataSourceOptions { startOperatorStateUniqueIds, endOperatorStateUniqueIds) } - private def resolvedCheckpointLocation( - hadoopConf: Configuration, - checkpointLocation: String): String = { - val checkpointPath = new Path(checkpointLocation) - val fs = checkpointPath.getFileSystem(hadoopConf) - checkpointPath.makeQualified(fs.getUri, fs.getWorkingDirectory).toUri.toString - } - private def getLastCommittedBatch(session: SparkSession, checkpointLocation: String): Long = { val commitLog = new StreamingQueryCheckpointMetadata(session, checkpointLocation).commitLog commitLog.getLatest() match { diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingUtils.scala new file mode 100644 index 000000000000..818c9e99e1d6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingUtils.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +object StreamingUtils { + def resolvedCheckpointLocation(hadoopConf: Configuration, checkpointLocation: String): String = { + val checkpointPath = new Path(checkpointLocation) + val fs = checkpointPath.getFileSystem(hadoopConf) + checkpointPath.makeQualified(fs.getUri, fs.getWorkingDirectory).toUri.toString + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionBatchRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionBatchRunner.scala new file mode 100644 index 000000000000..ec4df3ed06b3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionBatchRunner.scala @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.spark.internal.Logging +import org.apache.spark.internal.LogKeys._ +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.execution.streaming.StreamingUtils +import org.apache.spark.sql.execution.streaming.checkpointing.{CommitMetadata, OffsetSeq, OffsetSeqLog, OffsetSeqMetadata} +import org.apache.spark.sql.execution.streaming.runtime.StreamingQueryCheckpointMetadata +import org.apache.spark.sql.internal.SQLConf + +/** + * Offline batch runner for repartitioning the state stores used by a streaming query. 
+ * + * This class handles the process of creating a new microbatch, repartitioning state data + * across new partitions, and committing the changes to the checkpoint i.e. + * if the last streaming batch was batch `N`, this will create batch `N+1` with the repartitioned + * state. Note that this new batch doesn't read input data from sources, it only represents the + * repartition operation. The next time the streaming query is started, it will pick up from + * this new batch. + * + * @param sparkSession The active Spark session + * @param checkpointLocation The checkpoint location path + * @param numPartitions The new number of partitions to repartition to + * @param enforceExactlyOnceSink if we shouldn't allow skipping failed batches, + * to avoid duplicates in exactly once sinks. + */ +class OfflineStateRepartitionBatchRunner( + sparkSession: SparkSession, + checkpointLocation: String, + numPartitions: Int, + enforceExactlyOnceSink: Boolean = true) extends Logging { + + import OfflineStateRepartitionUtils._ + + private val hadoopConf = sparkSession.sessionState.newHadoopConf() + + private val resolvedCpLocation = StreamingUtils.resolvedCheckpointLocation( + hadoopConf, checkpointLocation) + + private val checkpointMetadata = new StreamingQueryCheckpointMetadata( + sparkSession, resolvedCpLocation) + + /** + * Runs a repartitioning batch and returns the batch ID. + * This will only return when the repartitioning is done. + * + * @return The repartition batch ID + */ + def run(): Long = { + logInfo(log"Starting offline state repartitioning for " + + log"checkpointLocation=${MDC(CHECKPOINT_LOCATION, checkpointLocation)}, " + + log"numPartitions=${MDC(NUM_PARTITIONS, numPartitions)}") + + val lastCommittedBatchId = getLastCommittedBatchId() + val lastBatchId = getLastBatchId() + + val newBatchId = createNewBatchIfNeeded(lastBatchId, lastCommittedBatchId) + + // todo: Do the repartitioning here, in subsequent PR + + // todo: update operator metadata in subsequent PR. + + // Commit the repartition batch + commitBatch(newBatchId, lastCommittedBatchId) + + logInfo(log"Completed state repartitioning with new " + + log"batchId=${MDC(BATCH_ID, newBatchId)}") + + newBatchId + } + + private def getLastCommittedBatchId(): Long = { + checkpointMetadata.commitLog.getLatestBatchId() match { + case Some(id) => id + // Needs at least 1 committed batch to repartition + case None => throw OfflineStateRepartitionErrors.noCommittedBatchError(checkpointLocation) + } + } + + private def getLastBatchId(): Long = { + checkpointMetadata.offsetLog.getLatestBatchId() match { + case Some(id) => id + case None => throw OfflineStateRepartitionErrors.noBatchFoundError(checkpointLocation) + } + } + + private def createNewBatchIfNeeded(lastBatchId: Long, lastCommittedBatchId: Long): Long = { + if (lastBatchId == lastCommittedBatchId) { + // Means there are no uncommitted batches. So start a new batch. + createNewBatchFromLastCommitted(lastBatchId, lastCommittedBatchId) + } else { + // Means there are uncommitted batches. + if (isRepartitionBatch(lastBatchId, checkpointMetadata.offsetLog, checkpointLocation)) { + // If it is a failed repartition batch, lets check if the shuffle partitions + // is the same as the requested. If same, then we can retry the batch. + val lastBatch = checkpointMetadata.offsetLog.get(lastBatchId).get + val lastBatchShufflePartitions = getShufflePartitions(lastBatch.metadata.get).get + if (lastBatchShufflePartitions == numPartitions) { + // We can retry the repartition batch. 
+ logInfo(log"The last batch is a failed repartition batch " + + log"(batchId=${MDC(BATCH_ID, lastBatchId)}). " + + log"Retrying it since it used the same number of shuffle partitions " + + log"as the requested ${MDC(NUM_PARTITIONS, numPartitions)}.") + lastBatchId + } else { + // Failed repartition should be retried with the same number of shuffle partitions. + // Once that completes successfully, then can repartition to another number + // of shuffle partitions. + throw OfflineStateRepartitionErrors.lastBatchAbandonedRepartitionError( + checkpointLocation, lastBatchId, lastBatchShufflePartitions, numPartitions) + } + } else { + if (enforceExactlyOnceSink) { + // We want the last batch to have committed successfully. + // Before proceeding with repartitioning, since repartitioning produces a new batch. + // If we skip the unsuccessful batch, this can cause duplicates in exactly-once sinks + // which uses the batchId to track already committed data. + throw OfflineStateRepartitionErrors.lastBatchFailedError(checkpointLocation, lastBatchId) + } else { + // We can skip the uncommitted batches. And repartition using the last committed + // batch state. Note that input data from the skipped failed batch will be reprocessed + // in the next query run. + skipUncommittedBatches(lastBatchId, lastCommittedBatchId) + // Now create a new batch + createNewBatchFromLastCommitted(lastBatchId, lastCommittedBatchId) + } + } + } + } + + private def skipUncommittedBatches(lastBatchId: Long, lastCommittedBatchId: Long): Unit = { + assert(lastBatchId > lastCommittedBatchId, + "Last batch ID must be greater than last committed batch ID") + + val fromBatchId = lastCommittedBatchId + 1 + for (batchId <- fromBatchId to lastBatchId) { + // write empty commit for these skipped batches + if (!checkpointMetadata.commitLog.add(batchId, CommitMetadata())) { + throw QueryExecutionErrors.concurrentStreamLogUpdate(batchId) + } + } + + logInfo(log"Skipped uncommitted batches from batchId " + + log"${MDC(BATCH_ID, fromBatchId)} to ${MDC(BATCH_ID, lastBatchId)}") + } + + /** + * Creates a new offset log entry for the repartition batch using the OffsetSeq + * of the last committed batch. But with a new number of partitions. + */ + private def createNewBatchFromLastCommitted( + lastBatchId: Long, + lastCommittedBatchId: Long): Long = { + val newBatchId = lastBatchId + 1 + // We want to repartition the state as of the last committed batch. + val lastCommittedOffsetSeq = checkpointMetadata.offsetLog.get(lastCommittedBatchId) + .getOrElse(throw OfflineStateRepartitionErrors + .offsetSeqNotFoundError(checkpointLocation, lastCommittedBatchId)) + + // Missing offset metadata not supported + val lastCommittedMetadata = lastCommittedOffsetSeq.metadata.getOrElse( + throw OfflineStateRepartitionErrors.missingOffsetSeqMetadataError( + checkpointLocation, version = 1, batchId = lastCommittedBatchId) + ) + + // No-op if the number of shuffle partitions in last commit is the same as the requested. 
+ if (getShufflePartitions(lastCommittedMetadata).get == numPartitions) { + throw OfflineStateRepartitionErrors.shufflePartitionsAlreadyMatchError( + checkpointLocation, lastCommittedBatchId, numPartitions) + } + + // Create a new OffsetSeq from the last committed but with an update num shuffle partitions + val newOffsetSeq = lastCommittedOffsetSeq match { + case v1: OffsetSeq => + val metadata = v1.metadata.get + v1.copy(metadata = Some(metadata.copy( + conf = metadata.conf + (SQLConf.SHUFFLE_PARTITIONS.key -> numPartitions.toString)))) + case _ => throw OfflineStateRepartitionErrors.unsupportedOffsetSeqVersionError( + checkpointLocation, version = -1) + } + + // Will fail if there is a concurrent operation on going + if (!checkpointMetadata.offsetLog.add(newBatchId, newOffsetSeq)) { + throw QueryExecutionErrors.concurrentStreamLogUpdate(newBatchId) + } + + logInfo(log"Created new offset log entry for repartition batch. " + + log"batchId=${MDC(BATCH_ID, newBatchId)}") + + newBatchId + } + + private def commitBatch(newBatchId: Long, lastCommittedBatchId: Long): Unit = { + val latestCommit = checkpointMetadata.commitLog.get(lastCommittedBatchId).get + + // todo: For checkpoint v2, we need to update the stateUniqueIds based on the + // newly created state commit. Will be done in subsequent PR. + if (!checkpointMetadata.commitLog.add(newBatchId, latestCommit)) { + throw QueryExecutionErrors.concurrentStreamLogUpdate(newBatchId) + } + } +} + +object OfflineStateRepartitionUtils { + def isRepartitionBatch( + batchId: Long, offsetLog: OffsetSeqLog, checkpointLocation: String): Boolean = { + assert(batchId >= 0, "Batch ID must be non-negative") + batchId match { + // first batch can never be a repartition batch since we require at least one committed batch + case 0 => false + case _ => + // A repartition batch is a batch where the number of shuffle partitions changed + // compared to the previous batch. + val batch = offsetLog.get(batchId).getOrElse(throw OfflineStateRepartitionErrors + .offsetSeqNotFoundError(checkpointLocation, batchId)) + val prevBatchId = batchId - 1 + val previousBatch = offsetLog.get(prevBatchId).getOrElse( + throw OfflineStateRepartitionErrors + .offsetSeqNotFoundError(checkpointLocation, prevBatchId)) + + val batchMetadata = batch.metadata.getOrElse(throw OfflineStateRepartitionErrors + .missingOffsetSeqMetadataError(checkpointLocation, version = 1, batchId = batchId)) + val shufflePartitions = getShufflePartitions(batchMetadata).get + + val previousBatchMetadata = previousBatch.metadata.getOrElse( + throw OfflineStateRepartitionErrors + .missingOffsetSeqMetadataError(checkpointLocation, version = 1, batchId = prevBatchId)) + val previousShufflePartitions = getShufflePartitions(previousBatchMetadata).get + + previousShufflePartitions != shufflePartitions + } + } + + def getShufflePartitions(metadata: OffsetSeqMetadata): Option[Int] = { + metadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key).map(_.toInt) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala new file mode 100644 index 000000000000..95b273826877 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.spark.{SparkIllegalArgumentException, SparkIllegalStateException} + +/** + * Errors thrown by Offline state repartitioning. + */ +object OfflineStateRepartitionErrors { + def parameterIsEmptyError(parameter: String): StateRepartitionInvalidParameterError = { + new StateRepartitionParameterIsEmptyError(parameter) + } + + def parameterIsNotGreaterThanZeroError( + parameter: String): StateRepartitionInvalidParameterError = { + new StateRepartitionParameterIsNotGreaterThanZeroError(parameter) + } + + def parameterIsNullError(parameter: String): StateRepartitionInvalidParameterError = { + new StateRepartitionParameterIsNullError(parameter) + } + + def lastBatchAbandonedRepartitionError( + checkpointLocation: String, + lastBatchId: Long, + lastBatchShufflePartitions: Int, + numPartitions: Int): StateRepartitionInvalidCheckpointError = { + new StateRepartitionLastBatchAbandonedRepartitionError( + checkpointLocation, lastBatchId, lastBatchShufflePartitions, numPartitions) + } + + def lastBatchFailedError( + checkpointLocation: String, + lastBatchId: Long): StateRepartitionInvalidCheckpointError = { + new StateRepartitionLastBatchFailedError(checkpointLocation, lastBatchId) + } + + def missingOffsetSeqMetadataError( + checkpointLocation: String, + version: Int, + batchId: Long): StateRepartitionInvalidCheckpointError = { + new StateRepartitionMissingOffsetSeqMetadataError(checkpointLocation, version, batchId) + } + + def noBatchFoundError(checkpointLocation: String): StateRepartitionInvalidCheckpointError = { + new StateRepartitionNoBatchFoundError(checkpointLocation) + } + + def noCommittedBatchError(checkpointLocation: String): StateRepartitionInvalidCheckpointError = { + new StateRepartitionNoCommittedBatchError(checkpointLocation) + } + + def offsetSeqNotFoundError( + checkpointLocation: String, + batchId: Long): StateRepartitionInvalidCheckpointError = { + new StateRepartitionOffsetSeqNotFoundError(checkpointLocation, batchId) + } + + def shufflePartitionsAlreadyMatchError( + checkpointLocation: String, + batchId: Long, + numPartitions: Int): StateRepartitionInvalidCheckpointError = { + new StateRepartitionShufflePartitionsAlreadyMatchError( + checkpointLocation, batchId, numPartitions) + } + + def unsupportedOffsetSeqVersionError( + checkpointLocation: String, + version: Int): StateRepartitionInvalidCheckpointError = { + new StateRepartitionUnsupportedOffsetSeqVersionError(checkpointLocation, version) + } +} + +/** + * Base class for exceptions thrown when an invalid parameter is passed + * into the repartition operation. 
+ */ +abstract class StateRepartitionInvalidParameterError( + parameter: String, + subClass: String, + messageParameters: Map[String, String] = Map.empty, + cause: Throwable = null) + extends SparkIllegalArgumentException( + errorClass = s"STATE_REPARTITION_INVALID_PARAMETER.$subClass", + messageParameters = Map("parameter" -> parameter) ++ messageParameters, + cause = cause) + +class StateRepartitionParameterIsEmptyError(parameter: String) + extends StateRepartitionInvalidParameterError( + parameter, + subClass = "IS_EMPTY") + +class StateRepartitionParameterIsNotGreaterThanZeroError(parameter: String) + extends StateRepartitionInvalidParameterError( + parameter, + subClass = "IS_NOT_GREATER_THAN_ZERO") + +class StateRepartitionParameterIsNullError(parameter: String) + extends StateRepartitionInvalidParameterError( + parameter, + subClass = "IS_NULL") + +/** + * Base class for exceptions thrown when the checkpoint location is in an invalid state + * for repartitioning. + */ +abstract class StateRepartitionInvalidCheckpointError( + checkpointLocation: String, + subClass: String, + messageParameters: Map[String, String], + cause: Throwable = null) + extends SparkIllegalStateException( + errorClass = s"STATE_REPARTITION_INVALID_CHECKPOINT.$subClass", + messageParameters = Map("checkpointLocation" -> checkpointLocation) ++ messageParameters, + cause = cause) + +class StateRepartitionLastBatchAbandonedRepartitionError( + checkpointLocation: String, + lastBatchId: Long, + lastBatchShufflePartitions: Int, + numPartitions: Int) + extends StateRepartitionInvalidCheckpointError( + checkpointLocation, + subClass = "LAST_BATCH_ABANDONED_REPARTITION", + messageParameters = Map( + "lastBatchId" -> lastBatchId.toString, + "lastBatchShufflePartitions" -> lastBatchShufflePartitions.toString, + "numPartitions" -> numPartitions.toString + )) + +class StateRepartitionLastBatchFailedError( + checkpointLocation: String, + lastBatchId: Long) + extends StateRepartitionInvalidCheckpointError( + checkpointLocation, + subClass = "LAST_BATCH_FAILED", + messageParameters = Map("lastBatchId" -> lastBatchId.toString)) + +class StateRepartitionMissingOffsetSeqMetadataError( + checkpointLocation: String, + version: Int, + batchId: Long) + extends StateRepartitionInvalidCheckpointError( + checkpointLocation, + subClass = "MISSING_OFFSET_SEQ_METADATA", + messageParameters = Map("version" -> version.toString, "batchId" -> batchId.toString)) + +class StateRepartitionNoBatchFoundError( + checkpointLocation: String) + extends StateRepartitionInvalidCheckpointError( + checkpointLocation, + subClass = "NO_BATCH_FOUND", + messageParameters = Map.empty) + +class StateRepartitionNoCommittedBatchError( + checkpointLocation: String) + extends StateRepartitionInvalidCheckpointError( + checkpointLocation, + subClass = "NO_COMMITTED_BATCH", + messageParameters = Map.empty) + +class StateRepartitionOffsetSeqNotFoundError( + checkpointLocation: String, + batchId: Long) + extends StateRepartitionInvalidCheckpointError( + checkpointLocation, + subClass = "OFFSET_SEQ_NOT_FOUND", + messageParameters = Map("batchId" -> batchId.toString)) + +class StateRepartitionShufflePartitionsAlreadyMatchError( + checkpointLocation: String, + batchId: Long, + numPartitions: Int) + extends StateRepartitionInvalidCheckpointError( + checkpointLocation, + subClass = "SHUFFLE_PARTITIONS_ALREADY_MATCH", + messageParameters = Map( + "batchId" -> batchId.toString, + "numPartitions" -> numPartitions.toString)) + +class 
StateRepartitionUnsupportedOffsetSeqVersionError( + checkpointLocation: String, + version: Int) + extends StateRepartitionInvalidCheckpointError( + checkpointLocation, + subClass = "UNSUPPORTED_OFFSET_SEQ_VERSION", + messageParameters = Map("version" -> version.toString)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index c967497b660c..ff6e58c2b2a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.classic.{SparkSession, Strategy, StreamingQueryManager, UDFRegistration} +import org.apache.spark.sql.classic.{SparkSession, Strategy, StreamingCheckpointManager, StreamingQueryManager, UDFRegistration} import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{ColumnarRule, CommandExecutionMode, QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser} @@ -418,6 +418,12 @@ abstract class BaseSessionStateBuilder( protected def streamingQueryManager: StreamingQueryManager = new StreamingQueryManager(session, conf) + /** + * Interface to manage streaming query checkpoints. + */ + private[spark] def streamingCheckpointManager: StreamingCheckpointManager = + new StreamingCheckpointManager(session, conf) + /** * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s * that listen for execution metrics. 
@@ -465,6 +471,7 @@ abstract class BaseSessionStateBuilder( () => optimizer, planner, () => streamingQueryManager, + () => streamingCheckpointManager, listenerManager, () => resourceLoader, createQueryExecution, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index 440148989ffb..2e921d0054e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.classic.{SparkSession, StreamingQueryManager, UDFRegistration} +import org.apache.spark.sql.classic.{SparkSession, StreamingCheckpointManager, StreamingQueryManager, UDFRegistration} import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.AdaptiveRulesHolder @@ -84,6 +84,7 @@ private[sql] class SessionState( optimizerBuilder: () => Optimizer, val planner: SparkPlanner, val streamingQueryManagerBuilder: () => StreamingQueryManager, + val streamingCheckpointManagerBuilder: () => StreamingCheckpointManager, val listenerManager: ExecutionListenerManager, resourceLoaderBuilder: () => SessionResourceLoader, createQueryExecution: (LogicalPlan, CommandExecutionMode.Value) => QueryExecution, @@ -106,6 +107,9 @@ private[sql] class SessionState( // when connecting to ThriftServer. lazy val streamingQueryManager: StreamingQueryManager = streamingQueryManagerBuilder() + private[spark] lazy val streamingCheckpointManager: StreamingCheckpointManager = + streamingCheckpointManagerBuilder() + lazy val artifactManager: ArtifactManager = artifactManagerBuilder() def catalogManager: CatalogManager = analyzer.catalogManager diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala new file mode 100644 index 000000000000..eadc26e21e21 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala @@ -0,0 +1,301 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.spark.sql.execution.streaming.checkpointing.{CommitLog, CommitMetadata} +import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, StreamingQueryCheckpointMetadata} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming._ + +/** + * Test for offline state repartitioning. This tests that repartition behaves as expected + * for different scenarios. + */ +class OfflineStateRepartitionSuite extends StreamTest { + import testImplicits._ + import OfflineStateRepartitionUtils._ + + test("Fail if empty checkpoint directory") { + withTempDir { dir => + val ex = intercept[StateRepartitionNoCommittedBatchError] { + spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 5) + } + + checkError( + ex, + condition = "STATE_REPARTITION_INVALID_CHECKPOINT.NO_COMMITTED_BATCH", + parameters = Map( + "checkpointLocation" -> dir.getAbsolutePath + ) + ) + } + } + + test("Fail if no batch found in checkpoint directory") { + withTempDir { dir => + // Write commit log but no offset log. + val commitLog = new CommitLog(spark, dir.getCanonicalPath + "/commits") + commitLog.add(0, CommitMetadata()) + + val ex = intercept[StateRepartitionNoBatchFoundError] { + spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 5) + } + + checkError( + ex, + condition = "STATE_REPARTITION_INVALID_CHECKPOINT.NO_BATCH_FOUND", + parameters = Map( + "checkpointLocation" -> dir.getAbsolutePath + ) + ) + } + } + + test("Fails if repartition parameter is invalid") { + val ex1 = intercept[StateRepartitionParameterIsNullError] { + spark.streamingCheckpointManager.repartition(null, 5) + } + + checkError( + ex1, + condition = "STATE_REPARTITION_INVALID_PARAMETER.IS_NULL", + parameters = Map("parameter" -> "checkpointLocation") + ) + + val ex2 = intercept[StateRepartitionParameterIsEmptyError] { + spark.streamingCheckpointManager.repartition("", 5) + } + + checkError( + ex2, + condition = "STATE_REPARTITION_INVALID_PARAMETER.IS_EMPTY", + parameters = Map("parameter" -> "checkpointLocation") + ) + + val ex3 = intercept[StateRepartitionParameterIsNotGreaterThanZeroError] { + spark.streamingCheckpointManager.repartition("test", 0) + } + + checkError( + ex3, + condition = "STATE_REPARTITION_INVALID_PARAMETER.IS_NOT_GREATER_THAN_ZERO", + parameters = Map("parameter" -> "numPartitions") + ) + } + + test("Repartition: success, failure, retry") { + withTempDir { dir => + val originalPartitions = 3 + val batchId = runSimpleStreamQuery(originalPartitions, dir.getAbsolutePath) + val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark, dir.getAbsolutePath) + // Shouldn't be seen as a repartition batch + assert(!isRepartitionBatch(batchId, checkpointMetadata.offsetLog, dir.getAbsolutePath)) + + // Trying to repartition to the same number should fail + val ex = intercept[StateRepartitionShufflePartitionsAlreadyMatchError] { + spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, originalPartitions) + } + checkError( + ex, + condition = "STATE_REPARTITION_INVALID_CHECKPOINT.SHUFFLE_PARTITIONS_ALREADY_MATCH", + parameters = Map( + "checkpointLocation" -> dir.getAbsolutePath, + "batchId" -> batchId.toString, + "numPartitions" -> originalPartitions.toString + ) + ) + + // Trying to repartition to a different number should succeed + val newPartitions = originalPartitions + 1 + spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, newPartitions) + val repartitionBatchId = batchId + 
1 + verifyRepartitionBatch( + repartitionBatchId, checkpointMetadata, dir.getAbsolutePath, newPartitions) + + // Now delete the repartition commit to simulate a failed repartition attempt. + // This will delete all the commits after the batchId. + checkpointMetadata.commitLog.purgeAfter(batchId) + + // Try to repartition with a different numPartitions should fail, + // since it will see an uncommitted repartition batch with a different numPartitions. + val ex2 = intercept[StateRepartitionLastBatchAbandonedRepartitionError] { + spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, newPartitions + 1) + } + checkError( + ex2, + condition = "STATE_REPARTITION_INVALID_CHECKPOINT.LAST_BATCH_ABANDONED_REPARTITION", + parameters = Map( + "checkpointLocation" -> dir.getAbsolutePath, + "lastBatchId" -> repartitionBatchId.toString, + "lastBatchShufflePartitions" -> newPartitions.toString, + "numPartitions" -> (newPartitions + 1).toString + ) + ) + + // Retrying with the same numPartitions should work + spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, newPartitions) + verifyRepartitionBatch( + repartitionBatchId, checkpointMetadata, dir.getAbsolutePath, newPartitions) + } + } + + test("Query last batch failed before repartitioning") { + withTempDir { dir => + val originalPartitions = 3 + // Run the query twice to produce two batches + val input = MemoryStream[Int] + val firstBatchId = runSimpleStreamQuery(originalPartitions, dir.getAbsolutePath, input) + val lastBatchId = runSimpleStreamQuery(originalPartitions, dir.getAbsolutePath, input) + val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark, dir.getAbsolutePath) + + // lets delete the last batch commit to simulate last query batch failed + checkpointMetadata.commitLog.purgeAfter(firstBatchId) + + // Now repartitioning should fail + val ex = intercept[StateRepartitionLastBatchFailedError] { + spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, originalPartitions + 1) + } + checkError( + ex, + condition = "STATE_REPARTITION_INVALID_CHECKPOINT.LAST_BATCH_FAILED", + parameters = Map( + "checkpointLocation" -> dir.getAbsolutePath, + "lastBatchId" -> lastBatchId.toString + ) + ) + + // Setting enforceExactlyOnceSink to false should allow repartitioning + spark.streamingCheckpointManager.repartition( + dir.getAbsolutePath, originalPartitions + 1, enforceExactlyOnceSink = false) + verifyRepartitionBatch( + lastBatchId + 1, + checkpointMetadata, + dir.getAbsolutePath, + originalPartitions + 1, + // Repartition should be based on the first batch, since we skipped the last one + baseBatchId = Some(firstBatchId)) + } + } + + test("Consecutive repartition") { + withTempDir { dir => + val originalPartitions = 3 + val batchId = runSimpleStreamQuery(originalPartitions, dir.getAbsolutePath) + + val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark, dir.getAbsolutePath) + + // decrease + spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, originalPartitions - 1) + verifyRepartitionBatch( + batchId + 1, + checkpointMetadata, + dir.getAbsolutePath, + originalPartitions - 1 + ) + + // increase + spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, originalPartitions + 1) + verifyRepartitionBatch( + batchId + 2, + checkpointMetadata, + dir.getAbsolutePath, + originalPartitions + 1 + ) + } + } + + private def runSimpleStreamQuery( + numPartitions: Int, + checkpointLocation: String, + input: MemoryStream[Int] = MemoryStream[Int]): Long = { + val conf = 
Map(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> numPartitions.toString) + + var committedBatchId: Long = -1 + testStream(input.toDF().groupBy().count(), outputMode = OutputMode.Update)( + StartStream(checkpointLocation = checkpointLocation, additionalConfs = conf), + AddData(input, 1, 2, 3), + ProcessAllAvailable(), + Execute { query => + committedBatchId = Option(query.lastProgress).map(_.batchId).getOrElse(-1) + } + ) + + assert(committedBatchId >= 0, "No batch was committed in the streaming query") + committedBatchId + } + + private def verifyRepartitionBatch( + batchId: Long, + checkpointMetadata: StreamingQueryCheckpointMetadata, + checkpointLocation: String, + expectedShufflePartitions: Int, + baseBatchId: Option[Long] = None): Unit = { + // Should be seen as a repartition batch + assert(isRepartitionBatch(batchId, checkpointMetadata.offsetLog, checkpointLocation)) + + // Verify the repartition batch + val lastBatchId = checkpointMetadata.offsetLog.getLatestBatchId().get + assert(lastBatchId == batchId) + + val lastBatch = checkpointMetadata.offsetLog.get(lastBatchId).get + val lastBatchShufflePartitions = getShufflePartitions(lastBatch.metadata.get).get + assert(lastBatchShufflePartitions == expectedShufflePartitions) + + // Verify the commit log + val lastCommitId = checkpointMetadata.commitLog.getLatestBatchId().get + assert(lastCommitId == batchId) + + // verify that the offset seq is the same between repartition batch and + // the batch the repartition is based on except for the shuffle partitions. + // When failed batches are skipped, then repartition can be based + // on an older batch and not batchId - 1. + val previousBatchId = baseBatchId.getOrElse(batchId - 1) + val previousBatch = checkpointMetadata.offsetLog.get(previousBatchId).get + + // Verify offsets are identical + assert(lastBatch.offsets == previousBatch.offsets, + s"Offsets should be identical between batch $previousBatchId and $batchId") + + // Verify metadata is the same except for shuffle partitions config + (lastBatch.metadata, previousBatch.metadata) match { + case (Some(lastMetadata), Some(previousMetadata)) => + // Check watermark and timestamp are the same + assert(lastMetadata.batchWatermarkMs == previousMetadata.batchWatermarkMs, + "Batch watermark should be the same") + assert(lastMetadata.batchTimestampMs == previousMetadata.batchTimestampMs, + "Batch timestamp should be the same") + + // Check all configs are the same except shuffle partitions + val lastConfWithoutShufflePartitions = + lastMetadata.conf - SQLConf.SHUFFLE_PARTITIONS.key + val previousConfWithoutShufflePartitions = + previousMetadata.conf - SQLConf.SHUFFLE_PARTITIONS.key + assert(lastConfWithoutShufflePartitions == previousConfWithoutShufflePartitions, + "All configs except shuffle partitions should be the same") + + // Verify shuffle partitions are different + assert( + getShufflePartitions(lastMetadata).get != getShufflePartitions(previousMetadata).get, + "Shuffle partitions should be different between batches") + case _ => + fail("Both batches should have metadata") + } + } +} From c7eb1cde201947d0843e5829fdca0dc3f5f17dfa Mon Sep 17 00:00:00 2001 From: micheal-o Date: Fri, 14 Nov 2025 11:17:31 -0800 Subject: [PATCH 2/9] err clss formating --- .../resources/error/error-conditions.json | 74 +++++++++---------- 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json 
b/common/utils/src/main/resources/error/error-conditions.json index 7757a66d339e..6b1e7f9c180a 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -5229,6 +5229,42 @@ ], "sqlState" : "0A000" }, + "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE" : { + "message" : [ + "Failed to perform stateful processor operation= with invalid handle state=." + ], + "sqlState" : "42802" + }, + "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_TIME_MODE" : { + "message" : [ + "Failed to perform stateful processor operation= with invalid timeMode=" + ], + "sqlState" : "42802" + }, + "STATEFUL_PROCESSOR_DUPLICATE_STATE_VARIABLE_DEFINED" : { + "message" : [ + "State variable with name has already been defined in the StatefulProcessor." + ], + "sqlState" : "42802" + }, + "STATEFUL_PROCESSOR_INCORRECT_TIME_MODE_TO_ASSIGN_TTL" : { + "message" : [ + "Cannot use TTL for state= in timeMode=, use TimeMode.ProcessingTime() instead." + ], + "sqlState" : "42802" + }, + "STATEFUL_PROCESSOR_TTL_DURATION_MUST_BE_POSITIVE" : { + "message" : [ + "TTL duration must be greater than zero for State store operation= on state=." + ], + "sqlState" : "42802" + }, + "STATEFUL_PROCESSOR_UNKNOWN_TIME_MODE" : { + "message" : [ + "Unknown time mode . Accepted timeMode modes are 'none', 'processingTime', 'eventTime'" + ], + "sqlState" : "42802" + }, "STATE_REPARTITION_INVALID_CHECKPOINT" : { "message" : [ "The provided checkpoint location '' is in an invalid state." @@ -5245,7 +5281,7 @@ "LAST_BATCH_FAILED" : { "message" : [ "The last batch ID didn't finish successfully. Please make sure the streaming query finishes successfully, before repartitioning.", - "If using ProcessingTime trigger, you can use AvailableNow trigger instead, which will make sure the query terminates successfully by itself.", + "If using ProcessingTime trigger, you can use AvailableNow trigger instead, which will make sure the query terminates successfully by itself.", "If you want to skip this check, set enforceExactlyOnceSink parameter in repartition to false.", "But this can cause duplicate output records from the failed batch when using exactly-once sinks." ] @@ -5308,42 +5344,6 @@ }, "sqlState" : "42616" }, - "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_HANDLE_STATE" : { - "message" : [ - "Failed to perform stateful processor operation= with invalid handle state=." - ], - "sqlState" : "42802" - }, - "STATEFUL_PROCESSOR_CANNOT_PERFORM_OPERATION_WITH_INVALID_TIME_MODE" : { - "message" : [ - "Failed to perform stateful processor operation= with invalid timeMode=" - ], - "sqlState" : "42802" - }, - "STATEFUL_PROCESSOR_DUPLICATE_STATE_VARIABLE_DEFINED" : { - "message" : [ - "State variable with name has already been defined in the StatefulProcessor." - ], - "sqlState" : "42802" - }, - "STATEFUL_PROCESSOR_INCORRECT_TIME_MODE_TO_ASSIGN_TTL" : { - "message" : [ - "Cannot use TTL for state= in timeMode=, use TimeMode.ProcessingTime() instead." - ], - "sqlState" : "42802" - }, - "STATEFUL_PROCESSOR_TTL_DURATION_MUST_BE_POSITIVE" : { - "message" : [ - "TTL duration must be greater than zero for State store operation= on state=." - ], - "sqlState" : "42802" - }, - "STATEFUL_PROCESSOR_UNKNOWN_TIME_MODE" : { - "message" : [ - "Unknown time mode . 
Accepted timeMode modes are 'none', 'processingTime', 'eventTime'" - ], - "sqlState" : "42802" - }, "STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS" : { "message" : [ "Failed to create column family with unsupported starting character and name=." From fba27ba629801c5a0e822ab3177cd6d3802bacfc Mon Sep 17 00:00:00 2001 From: micheal-o Date: Fri, 14 Nov 2025 11:37:05 -0800 Subject: [PATCH 3/9] nit --- .../apache/spark/sql/classic/StreamingCheckpointManager.scala | 4 ++-- ...nBatchRunner.scala => OfflineStateRepartitionRunner.scala} | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/{OfflineStateRepartitionBatchRunner.scala => OfflineStateRepartitionRunner.scala} (98%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingCheckpointManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingCheckpointManager.scala index ebe0f5b572dc..b7c4ff362b22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingCheckpointManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingCheckpointManager.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.classic import org.apache.spark.internal.Logging -import org.apache.spark.sql.execution.streaming.state.{OfflineStateRepartitionBatchRunner, OfflineStateRepartitionErrors} +import org.apache.spark.sql.execution.streaming.state.{OfflineStateRepartitionErrors, OfflineStateRepartitionRunner} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming @@ -44,7 +44,7 @@ private[spark] class StreamingCheckpointManager( throw OfflineStateRepartitionErrors.parameterIsNotGreaterThanZeroError("numPartitions") } - val runner = new OfflineStateRepartitionBatchRunner( + val runner = new OfflineStateRepartitionRunner( sparkSession, checkpointLocation, numPartitions, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionBatchRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionBatchRunner.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala index ec4df3ed06b3..d55764829087 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionBatchRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.streaming.runtime.StreamingQueryCheckpoint import org.apache.spark.sql.internal.SQLConf /** - * Offline batch runner for repartitioning the state stores used by a streaming query. + * Runs repartitioning for the state stores used by a streaming query. * * This class handles the process of creating a new microbatch, repartitioning state data * across new partitions, and committing the changes to the checkpoint i.e. @@ -42,7 +42,7 @@ import org.apache.spark.sql.internal.SQLConf * @param enforceExactlyOnceSink if we shouldn't allow skipping failed batches, * to avoid duplicates in exactly once sinks. 
*/ -class OfflineStateRepartitionBatchRunner( +class OfflineStateRepartitionRunner( sparkSession: SparkSession, checkpointLocation: String, numPartitions: Int, From 60f7a3fb3004bc14816901bfb7982d9fdecf6a98 Mon Sep 17 00:00:00 2001 From: micheal-o Date: Fri, 14 Nov 2025 16:37:43 -0800 Subject: [PATCH 4/9] lint --- .../StreamingCheckpointManager.scala | 29 ++++++++++--------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingCheckpointManager.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingCheckpointManager.scala index dfa6437b59e1..22aabcc00f38 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingCheckpointManager.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingCheckpointManager.scala @@ -24,24 +24,27 @@ private[spark] abstract class StreamingCheckpointManager { /** * Repartition the stateful streaming operators state in the streaming checkpoint to have - * `numPartitions` partitions. The streaming query MUST not be running. - * If `numPartitions` is the same as the current number of partitions, this is a no-op, - * and an exception will be thrown. + * `numPartitions` partitions. The streaming query MUST not be running. If `numPartitions` is + * the same as the current number of partitions, this is a no-op, and an exception will be + * thrown. * * This produces a new microbatch in the checkpoint that contains the repartitioned state i.e. - * if the last streaming batch was batch `N`, this will create batch `N+1` with the repartitioned - * state. Note that this new batch doesn't read input data from sources, it only represents the - * repartition operation. The next time the streaming query is started, it will pick up from - * this new batch. + * if the last streaming batch was batch `N`, this will create batch `N+1` with the + * repartitioned state. Note that this new batch doesn't read input data from sources, it only + * represents the repartition operation. The next time the streaming query is started, it will + * pick up from this new batch. * * This will return only when the repartitioning is complete or fails. * - * @note This operation should only be performed after the streaming query has been stopped. - * @param checkpointLocation The checkpoint location of the streaming query, should be the - * `checkpointLocation` option on the DataStreamWriter. - * @param numPartitions the target number of state partitions. - * @param enforceExactlyOnceSink if we shouldn't allow skipping failed batches, - * to avoid duplicates in exactly once sinks. + * @note + * This operation should only be performed after the streaming query has been stopped. + * @param checkpointLocation + * The checkpoint location of the streaming query, should be the `checkpointLocation` option + * on the DataStreamWriter. + * @param numPartitions + * the target number of state partitions. + * @param enforceExactlyOnceSink + * if we shouldn't allow skipping failed batches, to avoid duplicates in exactly once sinks. 
*/ private[spark] def repartition( checkpointLocation: String, From 51a7458685e76dc6e662969c0c26c4b5edad6912 Mon Sep 17 00:00:00 2001 From: micheal-o Date: Fri, 14 Nov 2025 16:53:34 -0800 Subject: [PATCH 5/9] move utils --- .../sql/execution/datasources/v2/state/StateDataSource.scala | 2 +- .../streaming/state/OfflineStateRepartitionRunner.scala | 2 +- .../sql/execution/streaming/{ => utils}/StreamingUtils.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/{ => utils}/StreamingUtils.scala (95%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index af478aa5c97c..6af418e1ddc2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -33,7 +33,6 @@ import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.{J import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues.JoinSideValues import org.apache.spark.sql.execution.datasources.v2.state.metadata.{StateMetadataPartitionReader, StateMetadataTableEntry} import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil -import org.apache.spark.sql.execution.streaming.StreamingUtils import org.apache.spark.sql.execution.streaming.checkpointing.OffsetSeqMetadata import org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorsUtils import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide} @@ -43,6 +42,7 @@ import org.apache.spark.sql.execution.streaming.operators.stateful.transformwith import org.apache.spark.sql.execution.streaming.runtime.StreamingCheckpointConstants.DIR_NAME_STATE import org.apache.spark.sql.execution.streaming.runtime.StreamingQueryCheckpointMetadata import org.apache.spark.sql.execution.streaming.state.{InMemoryStateSchemaProvider, KeyStateEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, StateSchemaCompatibilityChecker, StateSchemaMetadata, StateSchemaProvider, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProviderId} +import org.apache.spark.sql.execution.streaming.utils.StreamingUtils import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.streaming.TimeMode import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala index d55764829087..adc6d812cc13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala @@ -21,9 +21,9 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys._ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.execution.streaming.StreamingUtils import org.apache.spark.sql.execution.streaming.checkpointing.{CommitMetadata, OffsetSeq, OffsetSeqLog, OffsetSeqMetadata} import 
org.apache.spark.sql.execution.streaming.runtime.StreamingQueryCheckpointMetadata +import org.apache.spark.sql.execution.streaming.utils.StreamingUtils import org.apache.spark.sql.internal.SQLConf /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/utils/StreamingUtils.scala similarity index 95% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingUtils.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/utils/StreamingUtils.scala index 818c9e99e1d6..d2654ac943b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/utils/StreamingUtils.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.streaming +package org.apache.spark.sql.execution.streaming.utils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path From 7eefad32c0e2f075fdbcaec89cbfdf0da9bb9e58 Mon Sep 17 00:00:00 2001 From: micheal-o Date: Fri, 14 Nov 2025 17:54:22 -0800 Subject: [PATCH 6/9] extra log --- .../org/apache/spark/internal/LogKeys.java | 1 + .../state/OfflineStateRepartitionRunner.scala | 41 +++++++++++++------ 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java index 48bc6f201bc7..f65c64e5cc6f 100644 --- a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java +++ b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java @@ -192,6 +192,7 @@ public enum LogKeys implements LogKey { END_INDEX, END_POINT, END_VERSION, + ENFORCE_EXACTLY_ONCE, ENGINE, EPOCH, ERROR, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala index adc6d812cc13..2456b2c9b73b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.execution.streaming.checkpointing.{CommitMetadata, O import org.apache.spark.sql.execution.streaming.runtime.StreamingQueryCheckpointMetadata import org.apache.spark.sql.execution.streaming.utils.StreamingUtils import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Utils /** * Runs repartitioning for the state stores used by a streaming query. 
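PATCH 6/9 starts by registering ENFORCE_EXACTLY_ONCE in LogKeys so the flag can travel with structured log records. A minimal, self-contained illustration of how such a key is consumed (the example class is invented; the imports, MDC wrapper, and log interpolator mirror the pattern already used in the runner):

    import org.apache.spark.internal.{Logging, MDC}
    import org.apache.spark.internal.LogKeys._

    // Invented example class: attaches the new key to a structured log line.
    class RepartitionLoggingExample extends Logging {
      def announce(checkpointLocation: String, enforceExactlyOnceSink: Boolean): Unit = {
        logInfo(log"Repartition requested for " +
          log"checkpointLocation=${MDC(CHECKPOINT_LOCATION, checkpointLocation)}, " +
          log"enforceExactlyOnceSink=${MDC(ENFORCE_EXACTLY_ONCE, enforceExactlyOnceSink)}")
      }
    }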
@@ -67,24 +68,40 @@ class OfflineStateRepartitionRunner( def run(): Long = { logInfo(log"Starting offline state repartitioning for " + log"checkpointLocation=${MDC(CHECKPOINT_LOCATION, checkpointLocation)}, " + - log"numPartitions=${MDC(NUM_PARTITIONS, numPartitions)}") + log"numPartitions=${MDC(NUM_PARTITIONS, numPartitions)}, " + + log"enforceExactlyOnceSink=${MDC(ENFORCE_EXACTLY_ONCE, enforceExactlyOnceSink)}") - val lastCommittedBatchId = getLastCommittedBatchId() - val lastBatchId = getLastBatchId() + try { + val (repartitionBatchId, durationMs) = Utils.timeTakenMs { + val lastCommittedBatchId = getLastCommittedBatchId() + val lastBatchId = getLastBatchId() - val newBatchId = createNewBatchIfNeeded(lastBatchId, lastCommittedBatchId) + val newBatchId = createNewBatchIfNeeded(lastBatchId, lastCommittedBatchId) - // todo: Do the repartitioning here, in subsequent PR + // todo(SPARK-54365): Do the repartitioning here, in subsequent PR - // todo: update operator metadata in subsequent PR. + // todo(SPARK-54365): update operator metadata in subsequent PR. - // Commit the repartition batch - commitBatch(newBatchId, lastCommittedBatchId) - - logInfo(log"Completed state repartitioning with new " + - log"batchId=${MDC(BATCH_ID, newBatchId)}") + // Commit the repartition batch + commitBatch(newBatchId, lastCommittedBatchId) + newBatchId + } - newBatchId + logInfo(log"Completed state repartitioning for " + + log"checkpointLocation=${MDC(CHECKPOINT_LOCATION, checkpointLocation)}, " + + log"numPartitions=${MDC(NUM_PARTITIONS, numPartitions)}, " + + log"enforceExactlyOnceSink=${MDC(ENFORCE_EXACTLY_ONCE, enforceExactlyOnceSink)}, " + + log"repartitionBatchId=${MDC(BATCH_ID, repartitionBatchId)}, " + + log"durationMs=${MDC(DURATION, durationMs)}") + + repartitionBatchId + } catch { + case e: Throwable => + logError(log"State repartitioning failed for " + + log"checkpointLocation=${MDC(CHECKPOINT_LOCATION, checkpointLocation)}, " + + log"numPartitions=${MDC(NUM_PARTITIONS, numPartitions)}", e) + throw e + } } private def getLastCommittedBatchId(): Long = { From 9641d7ec4bbdd64fa5100f04b2488b35465dc531 Mon Sep 17 00:00:00 2001 From: micheal-o Date: Mon, 17 Nov 2025 14:45:00 -0800 Subject: [PATCH 7/9] spark version and multi skip --- .../src/main/resources/error/error-conditions.json | 4 ++-- .../state/OfflineStateRepartitionSuite.scala | 13 ++++++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 6b1e7f9c180a..032d741e6d3a 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -5288,7 +5288,7 @@ }, "MISSING_OFFSET_SEQ_METADATA" : { "message" : [ - "The OffsetSeq (v) metadata is missing for batch ID . Please make sure the checkpoint is from a supported Spark version." + "The OffsetSeq (v) metadata is missing for batch ID . Please make sure the checkpoint is from a supported Spark version (Spark 4.0+)." ] }, "NO_BATCH_FOUND" : { @@ -5315,7 +5315,7 @@ }, "UNSUPPORTED_OFFSET_SEQ_VERSION" : { "message" : [ - "Unsupported offset sequence version . Please make sure the checkpoint is from a supported Spark version." + "Unsupported offset sequence version . Please make sure the checkpoint is from a supported Spark version (Spark 4.0+)." 
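Earlier in PATCH 6/9, run() is rewrapped so the whole create-batch-and-commit sequence is timed with Utils.timeTakenMs and any failure is logged before being rethrown. Stripped of the repartition-specific fields, the control-flow shape is small enough to show on its own (the object and helper names below are made up for illustration):

    import org.apache.spark.util.Utils

    object TimedRunSketch {
      // Made-up helper mirroring the shape of run() above: time the body, report success with
      // its duration, and on any failure report it and rethrow so callers still see the error.
      def timedWithLogging[T](body: => T)(
          onSuccess: (T, Long) => Unit)(
          onFailure: Throwable => Unit): T = {
        try {
          val (result, durationMs) = Utils.timeTakenMs(body) // (value, elapsed milliseconds)
          onSuccess(result, durationMs)
          result
        } catch {
          case e: Throwable =>
            onFailure(e)
            throw e
        }
      }
    }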
] } }, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala index eadc26e21e21..86b5502b652e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala @@ -157,13 +157,16 @@ class OfflineStateRepartitionSuite extends StreamTest { test("Query last batch failed before repartitioning") { withTempDir { dir => val originalPartitions = 3 - // Run the query twice to produce two batches val input = MemoryStream[Int] - val firstBatchId = runSimpleStreamQuery(originalPartitions, dir.getAbsolutePath, input) - val lastBatchId = runSimpleStreamQuery(originalPartitions, dir.getAbsolutePath, input) + // Run 3 batches + val firstBatchId = 0 + val lastBatchId = firstBatchId + 2 + (firstBatchId to lastBatchId).foreach { _ => + runSimpleStreamQuery(originalPartitions, dir.getAbsolutePath, input) + } val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark, dir.getAbsolutePath) - // lets delete the last batch commit to simulate last query batch failed + // Lets keep only the first commit to simulate multiple failed batches checkpointMetadata.commitLog.purgeAfter(firstBatchId) // Now repartitioning should fail @@ -187,7 +190,7 @@ class OfflineStateRepartitionSuite extends StreamTest { checkpointMetadata, dir.getAbsolutePath, originalPartitions + 1, - // Repartition should be based on the first batch, since we skipped the last one + // Repartition should be based on the first batch, since we skipped the others baseBatchId = Some(firstBatchId)) } } From 11e3414ccbb79c75bed97f07bd8572c87dcd70bb Mon Sep 17 00:00:00 2001 From: micheal-o Date: Mon, 17 Nov 2025 19:29:25 -0800 Subject: [PATCH 8/9] nit --- .../apache/spark/sql/streaming/StreamingCheckpointManager.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingCheckpointManager.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingCheckpointManager.scala index 22aabcc00f38..31dcf84567a2 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingCheckpointManager.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingCheckpointManager.scala @@ -38,6 +38,7 @@ private[spark] abstract class StreamingCheckpointManager { * * @note * This operation should only be performed after the streaming query has been stopped. + * If not, can lead to undefined behavior or checkpoint corruption. * @param checkpointLocation * The checkpoint location of the streaming query, should be the `checkpointLocation` option * on the DataStreamWriter. 
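The reworked test in PATCH 7/9 simulates multiple failed batches by running three batches and then dropping every commit after the first, so offsets exist for batches 0 through 2 but only batch 0 counts as committed. Pulled out of the suite for clarity (spark, dir, and the helper that runs each batch come from the surrounding test; the rest is shown verbatim in the hunk above):

    // Keep only batch 0's commit so batches 1 and 2 look like they never finished.
    val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark, dir.getAbsolutePath)
    val firstBatchId = 0
    checkpointMetadata.commitLog.purgeAfter(firstBatchId)
    // The offset log still lists batches 0..2, so a plain repartition call is rejected;
    // the suite then repartitions anyway (presumably with enforceExactlyOnceSink disabled),
    // basing the new repartition batch on batch 0.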
From 58f28493c00068a37f73da64f63ebb2b8a25da10 Mon Sep 17 00:00:00 2001 From: micheal-o Date: Tue, 18 Nov 2025 10:43:43 -0800 Subject: [PATCH 9/9] lint --- .../spark/sql/streaming/StreamingCheckpointManager.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingCheckpointManager.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingCheckpointManager.scala index 31dcf84567a2..7bb6fda4818c 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingCheckpointManager.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingCheckpointManager.scala @@ -37,8 +37,8 @@ private[spark] abstract class StreamingCheckpointManager { * This will return only when the repartitioning is complete or fails. * * @note - * This operation should only be performed after the streaming query has been stopped. - * If not, can lead to undefined behavior or checkpoint corruption. + * This operation should only be performed after the streaming query has been stopped. If not, + * can lead to undefined behavior or checkpoint corruption. * @param checkpointLocation * The checkpoint location of the streaming query, should be the `checkpointLocation` option * on the DataStreamWriter.
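The last two patches only tighten the @note wording, but the warning it carries is operational: repartitioning while a query still owns the checkpoint can corrupt it. A conservative pre-check an internal caller might add (illustrative only; it simply requires that no query in the session is active at all):

    // Illustrative guard for the @note above: refuse to repartition while any streaming
    // query in this session is still running.
    require(
      spark.streams.active.isEmpty,
      "Stop all streaming queries before repartitioning a streaming checkpoint.")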