diff --git a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java index 48bc6f201bc7..f65c64e5cc6f 100644 --- a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java +++ b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java @@ -192,6 +192,7 @@ public enum LogKeys implements LogKey { END_INDEX, END_POINT, END_VERSION, + ENFORCE_EXACTLY_ONCE, ENGINE, EPOCH, ERROR, diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index f7eb1e63d7bd..032d741e6d3a 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -5265,6 +5265,85 @@ ], "sqlState" : "42802" }, + "STATE_REPARTITION_INVALID_CHECKPOINT" : { + "message" : [ + "The provided checkpoint location '<checkpointLocation>' is in an invalid state." + ], + "subClass" : { + "LAST_BATCH_ABANDONED_REPARTITION" : { + "message" : [ + "The last batch ID <lastBatchId> is a repartition batch with <lastBatchShufflePartitions> shuffle partitions and didn't finish successfully.", + "You're now requesting to repartition to <numPartitions> shuffle partitions.", + "Please retry with the same number of shuffle partitions as the previous attempt.", + "Once that completes successfully, you can repartition to another number of shuffle partitions." + ] + }, + "LAST_BATCH_FAILED" : { + "message" : [ + "The last batch ID <lastBatchId> didn't finish successfully. Please make sure the streaming query finishes successfully, before repartitioning.", + "If using ProcessingTime trigger, you can use AvailableNow trigger instead, which will make sure the query terminates successfully by itself.", + "If you want to skip this check, set enforceExactlyOnceSink parameter in repartition to false.", + "But this can cause duplicate output records from the failed batch when using exactly-once sinks." 
+ ] + }, + "MISSING_OFFSET_SEQ_METADATA" : { + "message" : [ + "The OffsetSeq (v<version>) metadata is missing for batch ID <batchId>. Please make sure the checkpoint is from a supported Spark version (Spark 4.0+)." + ] + }, + "NO_BATCH_FOUND" : { + "message" : [ + "No microbatch has been recorded in the checkpoint location. Make sure the streaming query has successfully completed at least one microbatch before repartitioning." + ] + }, + "NO_COMMITTED_BATCH" : { + "message" : [ + "There is no committed microbatch. Make sure the streaming query has successfully completed at least one microbatch before repartitioning." + ] + }, + "OFFSET_SEQ_NOT_FOUND" : { + "message" : [ + "Offset sequence entry for batch ID <batchId> not found. You might have set a very low value for", + "'spark.sql.streaming.minBatchesToRetain' config during the streaming query execution or you deleted files in the checkpoint location." + ] + }, + "SHUFFLE_PARTITIONS_ALREADY_MATCH" : { + "message" : [ + "The number of shuffle partitions in the last committed batch (id=<batchId>) is the same as the requested <numPartitions> partitions.", + "Hence, already has the requested number of partitions, so no-op." + ] + }, + "UNSUPPORTED_OFFSET_SEQ_VERSION" : { + "message" : [ + "Unsupported offset sequence version <version>. Please make sure the checkpoint is from a supported Spark version (Spark 4.0+)." + ] + } + }, + "sqlState" : "55019" + }, + "STATE_REPARTITION_INVALID_PARAMETER" : { + "message" : [ + "The repartition parameter <parameter> is invalid:" + ], + "subClass" : { + "IS_EMPTY" : { + "message" : [ + "cannot be empty." + ] + }, + "IS_NOT_GREATER_THAN_ZERO" : { + "message" : [ + "must be greater than zero." + ] + }, + "IS_NULL" : { + "message" : [ + "cannot be null." + ] + } + }, + "sqlState" : "42616" + }, "STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS" : { "message" : [ "Failed to create column family with unsupported starting character and name=." 
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingCheckpointManager.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingCheckpointManager.scala new file mode 100644 index 000000000000..7bb6fda4818c --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingCheckpointManager.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +/** + * A class to manage operations on streaming query checkpoints. + */ +private[spark] abstract class StreamingCheckpointManager { + + /** + * Repartition the stateful streaming operators state in the streaming checkpoint to have + * `numPartitions` partitions. The streaming query MUST NOT be running. If `numPartitions` is + * the same as the current number of partitions, nothing is repartitioned and an exception is + * thrown. + * + * This produces a new microbatch in the checkpoint that contains the repartitioned state i.e. + * if the last streaming batch was batch `N`, this will create batch `N+1` with the + * repartitioned state. Note that this new batch doesn't read input data from sources, it only + * represents the repartition operation. 
The next time the streaming query is started, it will + * pick up from this new batch. + * + * This will return only when the repartitioning is complete or fails. + * + * @note + * This operation should only be performed after the streaming query has been stopped. If not, + * it can lead to undefined behavior or checkpoint corruption. + * @param checkpointLocation + * The checkpoint location of the streaming query, should be the `checkpointLocation` option + * on the DataStreamWriter. + * @param numPartitions + * the target number of state partitions. + * @param enforceExactlyOnceSink + * if true, fail instead of skipping a failed batch, to avoid duplicates in exactly-once sinks. + */ + private[spark] def repartition( + checkpointLocation: String, + numPartitions: Int, + enforceExactlyOnceSink: Boolean = true): Unit +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/SQLContext.scala index 18a84d8c4299..711a64b4589a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/SQLContext.scala @@ -145,6 +145,12 @@ class SQLContext private[sql] (override val sparkSession: SparkSession) */ def streams: StreamingQueryManager = sparkSession.streams + /** + * Returns a `StreamingCheckpointManager` that allows managing any streaming checkpoint. 
+ */ + private[spark] def streamingCheckpointManager: StreamingCheckpointManager = + sparkSession.streamingCheckpointManager + /** @inheritdoc */ override def sparkContext: SparkContext = super.sparkContext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala index f7876d9a023b..1ac9941c07ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala @@ -223,6 +223,8 @@ class SparkSession private( @Unstable def streams: StreamingQueryManager = sessionState.streamingQueryManager + private[spark] def streamingCheckpointManager = sessionState.streamingCheckpointManager + /** * Returns an `ArtifactManager` that supports adding, managing and using session-scoped artifacts * (jars, classfiles, etc). diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingCheckpointManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingCheckpointManager.scala new file mode 100644 index 000000000000..b7c4ff362b22 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingCheckpointManager.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.classic + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.streaming.state.{OfflineStateRepartitionErrors, OfflineStateRepartitionRunner} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming + +/** @inheritdoc */ +private[spark] class StreamingCheckpointManager( + sparkSession: SparkSession, + sqlConf: SQLConf) extends streaming.StreamingCheckpointManager with Logging { + + /** @inheritdoc */ + override private[spark] def repartition( + checkpointLocation: String, + numPartitions: Int, + enforceExactlyOnceSink: Boolean = true): Unit = { + checkpointLocation match { + case null => + throw OfflineStateRepartitionErrors.parameterIsNullError("checkpointLocation") + case "" => + throw OfflineStateRepartitionErrors.parameterIsEmptyError("checkpointLocation") + case _ => // Valid case, no action needed + } + + if (numPartitions <= 0) { + throw OfflineStateRepartitionErrors.parameterIsNotGreaterThanZeroError("numPartitions") + } + + val runner = new OfflineStateRepartitionRunner( + sparkSession, + checkpointLocation, + numPartitions, + enforceExactlyOnceSink + ) + runner.run() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index fea4d345b8d0..6af418e1ddc2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -42,6 +42,7 @@ import org.apache.spark.sql.execution.streaming.operators.stateful.transformwith import org.apache.spark.sql.execution.streaming.runtime.StreamingCheckpointConstants.DIR_NAME_STATE import 
org.apache.spark.sql.execution.streaming.runtime.StreamingQueryCheckpointMetadata import org.apache.spark.sql.execution.streaming.state.{InMemoryStateSchemaProvider, KeyStateEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, StateSchemaCompatibilityChecker, StateSchemaMetadata, StateSchemaProvider, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProviderId} +import org.apache.spark.sql.execution.streaming.utils.StreamingUtils import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.streaming.TimeMode import org.apache.spark.sql.types.StructType @@ -481,7 +482,8 @@ object StateSourceOptions extends DataSourceOptions { throw StateDataSourceErrors.conflictOptions(Seq(JOIN_SIDE, STORE_NAME)) } - val resolvedCpLocation = resolvedCheckpointLocation(hadoopConf, checkpointLocation) + val resolvedCpLocation = StreamingUtils.resolvedCheckpointLocation( + hadoopConf, checkpointLocation) var batchId = Option(options.get(BATCH_ID)).map(_.toLong) @@ -617,14 +619,6 @@ object StateSourceOptions extends DataSourceOptions { startOperatorStateUniqueIds, endOperatorStateUniqueIds) } - private def resolvedCheckpointLocation( - hadoopConf: Configuration, - checkpointLocation: String): String = { - val checkpointPath = new Path(checkpointLocation) - val fs = checkpointPath.getFileSystem(hadoopConf) - checkpointPath.makeQualified(fs.getUri, fs.getWorkingDirectory).toUri.toString - } - private def getLastCommittedBatch(session: SparkSession, checkpointLocation: String): Long = { val commitLog = new StreamingQueryCheckpointMetadata(session, checkpointLocation).commitLog commitLog.getLatest() match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala new file mode 100644 index 000000000000..95b273826877 --- /dev/null +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.spark.{SparkIllegalArgumentException, SparkIllegalStateException} + +/** + * Errors thrown by Offline state repartitioning. 
+ */ +object OfflineStateRepartitionErrors { + def parameterIsEmptyError(parameter: String): StateRepartitionInvalidParameterError = { + new StateRepartitionParameterIsEmptyError(parameter) + } + + def parameterIsNotGreaterThanZeroError( + parameter: String): StateRepartitionInvalidParameterError = { + new StateRepartitionParameterIsNotGreaterThanZeroError(parameter) + } + + def parameterIsNullError(parameter: String): StateRepartitionInvalidParameterError = { + new StateRepartitionParameterIsNullError(parameter) + } + + def lastBatchAbandonedRepartitionError( + checkpointLocation: String, + lastBatchId: Long, + lastBatchShufflePartitions: Int, + numPartitions: Int): StateRepartitionInvalidCheckpointError = { + new StateRepartitionLastBatchAbandonedRepartitionError( + checkpointLocation, lastBatchId, lastBatchShufflePartitions, numPartitions) + } + + def lastBatchFailedError( + checkpointLocation: String, + lastBatchId: Long): StateRepartitionInvalidCheckpointError = { + new StateRepartitionLastBatchFailedError(checkpointLocation, lastBatchId) + } + + def missingOffsetSeqMetadataError( + checkpointLocation: String, + version: Int, + batchId: Long): StateRepartitionInvalidCheckpointError = { + new StateRepartitionMissingOffsetSeqMetadataError(checkpointLocation, version, batchId) + } + + def noBatchFoundError(checkpointLocation: String): StateRepartitionInvalidCheckpointError = { + new StateRepartitionNoBatchFoundError(checkpointLocation) + } + + def noCommittedBatchError(checkpointLocation: String): StateRepartitionInvalidCheckpointError = { + new StateRepartitionNoCommittedBatchError(checkpointLocation) + } + + def offsetSeqNotFoundError( + checkpointLocation: String, + batchId: Long): StateRepartitionInvalidCheckpointError = { + new StateRepartitionOffsetSeqNotFoundError(checkpointLocation, batchId) + } + + def shufflePartitionsAlreadyMatchError( + checkpointLocation: String, + batchId: Long, + numPartitions: Int): StateRepartitionInvalidCheckpointError = { + new 
StateRepartitionShufflePartitionsAlreadyMatchError( + checkpointLocation, batchId, numPartitions) + } + + def unsupportedOffsetSeqVersionError( + checkpointLocation: String, + version: Int): StateRepartitionInvalidCheckpointError = { + new StateRepartitionUnsupportedOffsetSeqVersionError(checkpointLocation, version) + } +} + +/** + * Base class for exceptions thrown when an invalid parameter is passed + * into the repartition operation. + */ +abstract class StateRepartitionInvalidParameterError( + parameter: String, + subClass: String, + messageParameters: Map[String, String] = Map.empty, + cause: Throwable = null) + extends SparkIllegalArgumentException( + errorClass = s"STATE_REPARTITION_INVALID_PARAMETER.$subClass", + messageParameters = Map("parameter" -> parameter) ++ messageParameters, + cause = cause) + +class StateRepartitionParameterIsEmptyError(parameter: String) + extends StateRepartitionInvalidParameterError( + parameter, + subClass = "IS_EMPTY") + +class StateRepartitionParameterIsNotGreaterThanZeroError(parameter: String) + extends StateRepartitionInvalidParameterError( + parameter, + subClass = "IS_NOT_GREATER_THAN_ZERO") + +class StateRepartitionParameterIsNullError(parameter: String) + extends StateRepartitionInvalidParameterError( + parameter, + subClass = "IS_NULL") + +/** + * Base class for exceptions thrown when the checkpoint location is in an invalid state + * for repartitioning. 
+ */ +abstract class StateRepartitionInvalidCheckpointError( + checkpointLocation: String, + subClass: String, + messageParameters: Map[String, String], + cause: Throwable = null) + extends SparkIllegalStateException( + errorClass = s"STATE_REPARTITION_INVALID_CHECKPOINT.$subClass", + messageParameters = Map("checkpointLocation" -> checkpointLocation) ++ messageParameters, + cause = cause) + +class StateRepartitionLastBatchAbandonedRepartitionError( + checkpointLocation: String, + lastBatchId: Long, + lastBatchShufflePartitions: Int, + numPartitions: Int) + extends StateRepartitionInvalidCheckpointError( + checkpointLocation, + subClass = "LAST_BATCH_ABANDONED_REPARTITION", + messageParameters = Map( + "lastBatchId" -> lastBatchId.toString, + "lastBatchShufflePartitions" -> lastBatchShufflePartitions.toString, + "numPartitions" -> numPartitions.toString + )) + +class StateRepartitionLastBatchFailedError( + checkpointLocation: String, + lastBatchId: Long) + extends StateRepartitionInvalidCheckpointError( + checkpointLocation, + subClass = "LAST_BATCH_FAILED", + messageParameters = Map("lastBatchId" -> lastBatchId.toString)) + +class StateRepartitionMissingOffsetSeqMetadataError( + checkpointLocation: String, + version: Int, + batchId: Long) + extends StateRepartitionInvalidCheckpointError( + checkpointLocation, + subClass = "MISSING_OFFSET_SEQ_METADATA", + messageParameters = Map("version" -> version.toString, "batchId" -> batchId.toString)) + +class StateRepartitionNoBatchFoundError( + checkpointLocation: String) + extends StateRepartitionInvalidCheckpointError( + checkpointLocation, + subClass = "NO_BATCH_FOUND", + messageParameters = Map.empty) + +class StateRepartitionNoCommittedBatchError( + checkpointLocation: String) + extends StateRepartitionInvalidCheckpointError( + checkpointLocation, + subClass = "NO_COMMITTED_BATCH", + messageParameters = Map.empty) + +class StateRepartitionOffsetSeqNotFoundError( + checkpointLocation: String, + batchId: Long) + extends 
StateRepartitionInvalidCheckpointError( + checkpointLocation, + subClass = "OFFSET_SEQ_NOT_FOUND", + messageParameters = Map("batchId" -> batchId.toString)) + +class StateRepartitionShufflePartitionsAlreadyMatchError( + checkpointLocation: String, + batchId: Long, + numPartitions: Int) + extends StateRepartitionInvalidCheckpointError( + checkpointLocation, + subClass = "SHUFFLE_PARTITIONS_ALREADY_MATCH", + messageParameters = Map( + "batchId" -> batchId.toString, + "numPartitions" -> numPartitions.toString)) + +class StateRepartitionUnsupportedOffsetSeqVersionError( + checkpointLocation: String, + version: Int) + extends StateRepartitionInvalidCheckpointError( + checkpointLocation, + subClass = "UNSUPPORTED_OFFSET_SEQ_VERSION", + messageParameters = Map("version" -> version.toString)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala new file mode 100644 index 000000000000..2456b2c9b73b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala @@ -0,0 +1,272 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.spark.internal.Logging +import org.apache.spark.internal.LogKeys._ +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.execution.streaming.checkpointing.{CommitMetadata, OffsetSeq, OffsetSeqLog, OffsetSeqMetadata} +import org.apache.spark.sql.execution.streaming.runtime.StreamingQueryCheckpointMetadata +import org.apache.spark.sql.execution.streaming.utils.StreamingUtils +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Utils + +/** + * Runs repartitioning for the state stores used by a streaming query. + * + * This class handles the process of creating a new microbatch, repartitioning state data + * across new partitions, and committing the changes to the checkpoint i.e. + * if the last streaming batch was batch `N`, this will create batch `N+1` with the repartitioned + * state. Note that this new batch doesn't read input data from sources, it only represents the + * repartition operation. The next time the streaming query is started, it will pick up from + * this new batch. + * + * @param sparkSession The active Spark session + * @param checkpointLocation The checkpoint location path + * @param numPartitions The new number of partitions to repartition to + * @param enforceExactlyOnceSink if we shouldn't allow skipping failed batches, + * to avoid duplicates in exactly once sinks. 
+ */ +class OfflineStateRepartitionRunner( + sparkSession: SparkSession, + checkpointLocation: String, + numPartitions: Int, + enforceExactlyOnceSink: Boolean = true) extends Logging { + + import OfflineStateRepartitionUtils._ + + private val hadoopConf = sparkSession.sessionState.newHadoopConf() + + private val resolvedCpLocation = StreamingUtils.resolvedCheckpointLocation( + hadoopConf, checkpointLocation) + + private val checkpointMetadata = new StreamingQueryCheckpointMetadata( + sparkSession, resolvedCpLocation) + + /** + * Runs a repartitioning batch and returns the batch ID. + * This will only return when the repartitioning is done. + * + * @return The repartition batch ID + */ + def run(): Long = { + logInfo(log"Starting offline state repartitioning for " + + log"checkpointLocation=${MDC(CHECKPOINT_LOCATION, checkpointLocation)}, " + + log"numPartitions=${MDC(NUM_PARTITIONS, numPartitions)}, " + + log"enforceExactlyOnceSink=${MDC(ENFORCE_EXACTLY_ONCE, enforceExactlyOnceSink)}") + + try { + val (repartitionBatchId, durationMs) = Utils.timeTakenMs { + val lastCommittedBatchId = getLastCommittedBatchId() + val lastBatchId = getLastBatchId() + + val newBatchId = createNewBatchIfNeeded(lastBatchId, lastCommittedBatchId) + + // todo(SPARK-54365): Do the repartitioning here, in subsequent PR + + // todo(SPARK-54365): update operator metadata in subsequent PR. 
+ + // Commit the repartition batch + commitBatch(newBatchId, lastCommittedBatchId) + newBatchId + } + + logInfo(log"Completed state repartitioning for " + + log"checkpointLocation=${MDC(CHECKPOINT_LOCATION, checkpointLocation)}, " + + log"numPartitions=${MDC(NUM_PARTITIONS, numPartitions)}, " + + log"enforceExactlyOnceSink=${MDC(ENFORCE_EXACTLY_ONCE, enforceExactlyOnceSink)}, " + + log"repartitionBatchId=${MDC(BATCH_ID, repartitionBatchId)}, " + + log"durationMs=${MDC(DURATION, durationMs)}") + + repartitionBatchId + } catch { + case e: Throwable => + logError(log"State repartitioning failed for " + + log"checkpointLocation=${MDC(CHECKPOINT_LOCATION, checkpointLocation)}, " + + log"numPartitions=${MDC(NUM_PARTITIONS, numPartitions)}", e) + throw e + } + } + + private def getLastCommittedBatchId(): Long = { + checkpointMetadata.commitLog.getLatestBatchId() match { + case Some(id) => id + // Needs at least 1 committed batch to repartition + case None => throw OfflineStateRepartitionErrors.noCommittedBatchError(checkpointLocation) + } + } + + private def getLastBatchId(): Long = { + checkpointMetadata.offsetLog.getLatestBatchId() match { + case Some(id) => id + case None => throw OfflineStateRepartitionErrors.noBatchFoundError(checkpointLocation) + } + } + + private def createNewBatchIfNeeded(lastBatchId: Long, lastCommittedBatchId: Long): Long = { + if (lastBatchId == lastCommittedBatchId) { + // Means there are no uncommitted batches. So start a new batch. + createNewBatchFromLastCommitted(lastBatchId, lastCommittedBatchId) + } else { + // Means there are uncommitted batches. + if (isRepartitionBatch(lastBatchId, checkpointMetadata.offsetLog, checkpointLocation)) { + // If it is a failed repartition batch, lets check if the shuffle partitions + // is the same as the requested. If same, then we can retry the batch. 
+ val lastBatch = checkpointMetadata.offsetLog.get(lastBatchId).get + val lastBatchShufflePartitions = getShufflePartitions(lastBatch.metadata.get).get + if (lastBatchShufflePartitions == numPartitions) { + // We can retry the repartition batch. + logInfo(log"The last batch is a failed repartition batch " + + log"(batchId=${MDC(BATCH_ID, lastBatchId)}). " + + log"Retrying it since it used the same number of shuffle partitions " + + log"as the requested ${MDC(NUM_PARTITIONS, numPartitions)}.") + lastBatchId + } else { + // Failed repartition should be retried with the same number of shuffle partitions. + // Once that completes successfully, then can repartition to another number + // of shuffle partitions. + throw OfflineStateRepartitionErrors.lastBatchAbandonedRepartitionError( + checkpointLocation, lastBatchId, lastBatchShufflePartitions, numPartitions) + } + } else { + if (enforceExactlyOnceSink) { + // We want the last batch to have committed successfully. + // Before proceeding with repartitioning, since repartitioning produces a new batch. + // If we skip the unsuccessful batch, this can cause duplicates in exactly-once sinks + // which uses the batchId to track already committed data. + throw OfflineStateRepartitionErrors.lastBatchFailedError(checkpointLocation, lastBatchId) + } else { + // We can skip the uncommitted batches. And repartition using the last committed + // batch state. Note that input data from the skipped failed batch will be reprocessed + // in the next query run. 
+ skipUncommittedBatches(lastBatchId, lastCommittedBatchId) + // Now create a new batch + createNewBatchFromLastCommitted(lastBatchId, lastCommittedBatchId) + } + } + } + } + + private def skipUncommittedBatches(lastBatchId: Long, lastCommittedBatchId: Long): Unit = { + assert(lastBatchId > lastCommittedBatchId, + "Last batch ID must be greater than last committed batch ID") + + val fromBatchId = lastCommittedBatchId + 1 + for (batchId <- fromBatchId to lastBatchId) { + // write empty commit for these skipped batches + if (!checkpointMetadata.commitLog.add(batchId, CommitMetadata())) { + throw QueryExecutionErrors.concurrentStreamLogUpdate(batchId) + } + } + + logInfo(log"Skipped uncommitted batches from batchId " + + log"${MDC(BATCH_ID, fromBatchId)} to ${MDC(BATCH_ID, lastBatchId)}") + } + + /** + * Creates a new offset log entry for the repartition batch using the OffsetSeq + * of the last committed batch. But with a new number of partitions. + */ + private def createNewBatchFromLastCommitted( + lastBatchId: Long, + lastCommittedBatchId: Long): Long = { + val newBatchId = lastBatchId + 1 + // We want to repartition the state as of the last committed batch. + val lastCommittedOffsetSeq = checkpointMetadata.offsetLog.get(lastCommittedBatchId) + .getOrElse(throw OfflineStateRepartitionErrors + .offsetSeqNotFoundError(checkpointLocation, lastCommittedBatchId)) + + // Missing offset metadata not supported + val lastCommittedMetadata = lastCommittedOffsetSeq.metadata.getOrElse( + throw OfflineStateRepartitionErrors.missingOffsetSeqMetadataError( + checkpointLocation, version = 1, batchId = lastCommittedBatchId) + ) + + // No-op if the number of shuffle partitions in last commit is the same as the requested. 
+ if (getShufflePartitions(lastCommittedMetadata).get == numPartitions) { + throw OfflineStateRepartitionErrors.shufflePartitionsAlreadyMatchError( + checkpointLocation, lastCommittedBatchId, numPartitions) + } + + // Create a new OffsetSeq from the last committed but with an update num shuffle partitions + val newOffsetSeq = lastCommittedOffsetSeq match { + case v1: OffsetSeq => + val metadata = v1.metadata.get + v1.copy(metadata = Some(metadata.copy( + conf = metadata.conf + (SQLConf.SHUFFLE_PARTITIONS.key -> numPartitions.toString)))) + case _ => throw OfflineStateRepartitionErrors.unsupportedOffsetSeqVersionError( + checkpointLocation, version = -1) + } + + // Will fail if there is a concurrent operation on going + if (!checkpointMetadata.offsetLog.add(newBatchId, newOffsetSeq)) { + throw QueryExecutionErrors.concurrentStreamLogUpdate(newBatchId) + } + + logInfo(log"Created new offset log entry for repartition batch. " + + log"batchId=${MDC(BATCH_ID, newBatchId)}") + + newBatchId + } + + private def commitBatch(newBatchId: Long, lastCommittedBatchId: Long): Unit = { + val latestCommit = checkpointMetadata.commitLog.get(lastCommittedBatchId).get + + // todo: For checkpoint v2, we need to update the stateUniqueIds based on the + // newly created state commit. Will be done in subsequent PR. + if (!checkpointMetadata.commitLog.add(newBatchId, latestCommit)) { + throw QueryExecutionErrors.concurrentStreamLogUpdate(newBatchId) + } + } +} + +object OfflineStateRepartitionUtils { + def isRepartitionBatch( + batchId: Long, offsetLog: OffsetSeqLog, checkpointLocation: String): Boolean = { + assert(batchId >= 0, "Batch ID must be non-negative") + batchId match { + // first batch can never be a repartition batch since we require at least one committed batch + case 0 => false + case _ => + // A repartition batch is a batch where the number of shuffle partitions changed + // compared to the previous batch. 
+ val batch = offsetLog.get(batchId).getOrElse(throw OfflineStateRepartitionErrors + .offsetSeqNotFoundError(checkpointLocation, batchId)) + val prevBatchId = batchId - 1 + val previousBatch = offsetLog.get(prevBatchId).getOrElse( + throw OfflineStateRepartitionErrors + .offsetSeqNotFoundError(checkpointLocation, prevBatchId)) + + val batchMetadata = batch.metadata.getOrElse(throw OfflineStateRepartitionErrors + .missingOffsetSeqMetadataError(checkpointLocation, version = 1, batchId = batchId)) + val shufflePartitions = getShufflePartitions(batchMetadata).get + + val previousBatchMetadata = previousBatch.metadata.getOrElse( + throw OfflineStateRepartitionErrors + .missingOffsetSeqMetadataError(checkpointLocation, version = 1, batchId = prevBatchId)) + val previousShufflePartitions = getShufflePartitions(previousBatchMetadata).get + + previousShufflePartitions != shufflePartitions + } + } + + def getShufflePartitions(metadata: OffsetSeqMetadata): Option[Int] = { + metadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key).map(_.toInt) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/utils/StreamingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/utils/StreamingUtils.scala new file mode 100644 index 000000000000..d2654ac943b2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/utils/StreamingUtils.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming.utils
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+
+object StreamingUtils {
+  /**
+   * Resolves `checkpointLocation` to a fully qualified location string.
+   *
+   * The path is qualified with the URI and the working directory of the file system that
+   * `hadoopConf` yields for it.
+   *
+   * @param hadoopConf Hadoop configuration used to look up the file system of the path
+   * @param checkpointLocation checkpoint location as given by the caller
+   * @return the qualified path, rendered as a URI string
+   */
+  def resolvedCheckpointLocation(hadoopConf: Configuration, checkpointLocation: String): String = {
+    val rawPath = new Path(checkpointLocation)
+    val fileSystem = rawPath.getFileSystem(hadoopConf)
+    val qualifiedPath = rawPath.makeQualified(fileSystem.getUri, fileSystem.getWorkingDirectory)
+    qualifiedPath.toUri.toString
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index c967497b660c..ff6e58c2b2a2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.optimizer.Optimizer
 import org.apache.spark.sql.catalyst.parser.ParserInterface
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.classic.{SparkSession, Strategy, StreamingQueryManager, UDFRegistration}
+import org.apache.spark.sql.classic.{SparkSession, Strategy, StreamingCheckpointManager, StreamingQueryManager, UDFRegistration}
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.{ColumnarRule, CommandExecutionMode, QueryExecution, SparkOptimizer,
SparkPlanner, SparkSqlParser}
@@ -418,6 +418,12 @@ abstract class BaseSessionStateBuilder(
   protected def streamingQueryManager: StreamingQueryManager =
     new StreamingQueryManager(session, conf)
 
+  /**
+   * Creates the [[StreamingCheckpointManager]] for this session from `session` and `conf`.
+   */
+  private[spark] def streamingCheckpointManager: StreamingCheckpointManager =
+    new StreamingCheckpointManager(session, conf)
+
   /**
    * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s
    * that listen for execution metrics.
@@ -465,6 +471,7 @@ abstract class BaseSessionStateBuilder(
       () => optimizer,
       planner,
       () => streamingQueryManager,
+      () => streamingCheckpointManager,
       listenerManager,
       () => resourceLoader,
       createQueryExecution,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index 440148989ffb..2e921d0054e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.optimizer.Optimizer
 import org.apache.spark.sql.catalyst.parser.ParserInterface
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.classic.{SparkSession, StreamingQueryManager, UDFRegistration}
+import org.apache.spark.sql.classic.{SparkSession, StreamingCheckpointManager, StreamingQueryManager, UDFRegistration}
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.AdaptiveRulesHolder
@@ -84,6 +84,7 @@ private[sql] class SessionState(
     optimizerBuilder: () => Optimizer,
     val planner: SparkPlanner,
     val streamingQueryManagerBuilder: () => StreamingQueryManager,
+    val streamingCheckpointManagerBuilder: () => StreamingCheckpointManager,
     val listenerManager: ExecutionListenerManager,
     resourceLoaderBuilder: () => SessionResourceLoader,
     createQueryExecution: (LogicalPlan, CommandExecutionMode.Value) => QueryExecution,
@@ -106,6 +107,9 @@ private[sql] class SessionState(
   // when connecting to ThriftServer.
   lazy val streamingQueryManager: StreamingQueryManager = streamingQueryManagerBuilder()
 
+  private[spark] lazy val streamingCheckpointManager: StreamingCheckpointManager =
+    streamingCheckpointManagerBuilder()
+
   lazy val artifactManager: ArtifactManager = artifactManagerBuilder()
 
   def catalogManager: CatalogManager = analyzer.catalogManager
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala
new file mode 100644
index 000000000000..86b5502b652e
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala
@@ -0,0 +1,304 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import org.apache.spark.sql.execution.streaming.checkpointing.{CommitLog, CommitMetadata}
+import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, StreamingQueryCheckpointMetadata}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming._
+
+/**
+ * Tests for offline state repartitioning through
+ * `spark.streamingCheckpointManager.repartition`. The scenarios cover invalid checkpoints,
+ * invalid parameters, a successful repartition, retrying an abandoned repartition,
+ * repartitioning after failed batches, and consecutive repartitions.
+ */
+class OfflineStateRepartitionSuite extends StreamTest {
+  import testImplicits._
+  import OfflineStateRepartitionUtils._
+
+  // A freshly created temp directory holds no commit log, so repartition must be rejected.
+  test("Fail if empty checkpoint directory") {
+    withTempDir { dir =>
+      val ex = intercept[StateRepartitionNoCommittedBatchError] {
+        spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 5)
+      }
+
+      checkError(
+        ex,
+        condition = "STATE_REPARTITION_INVALID_CHECKPOINT.NO_COMMITTED_BATCH",
+        parameters = Map(
+          "checkpointLocation" -> dir.getAbsolutePath
+        )
+      )
+    }
+  }
+
+  // A commit exists but the offset log is empty, so the failure is NO_BATCH_FOUND
+  // instead of NO_COMMITTED_BATCH.
+  test("Fail if no batch found in checkpoint directory") {
+    withTempDir { dir =>
+      // Write commit log but no offset log.
+      val commitLog = new CommitLog(spark, dir.getCanonicalPath + "/commits")
+      commitLog.add(0, CommitMetadata())
+
+      val ex = intercept[StateRepartitionNoBatchFoundError] {
+        spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 5)
+      }
+
+      checkError(
+        ex,
+        condition = "STATE_REPARTITION_INVALID_CHECKPOINT.NO_BATCH_FOUND",
+        parameters = Map(
+          "checkpointLocation" -> dir.getAbsolutePath
+        )
+      )
+    }
+  }
+
+  test("Fails if repartition parameter is invalid") {
+    // null checkpoint location
+    val ex1 = intercept[StateRepartitionParameterIsNullError] {
+      spark.streamingCheckpointManager.repartition(null, 5)
+    }
+
+    checkError(
+      ex1,
+      condition = "STATE_REPARTITION_INVALID_PARAMETER.IS_NULL",
+      parameters = Map("parameter" -> "checkpointLocation")
+    )
+
+    // empty checkpoint location
+    val ex2 = intercept[StateRepartitionParameterIsEmptyError] {
+      spark.streamingCheckpointManager.repartition("", 5)
+    }
+
+    checkError(
+      ex2,
+      condition = "STATE_REPARTITION_INVALID_PARAMETER.IS_EMPTY",
+      parameters = Map("parameter" -> "checkpointLocation")
+    )
+
+    // non-positive number of partitions
+    val ex3 = intercept[StateRepartitionParameterIsNotGreaterThanZeroError] {
+      spark.streamingCheckpointManager.repartition("test", 0)
+    }
+
+    checkError(
+      ex3,
+      condition = "STATE_REPARTITION_INVALID_PARAMETER.IS_NOT_GREATER_THAN_ZERO",
+      parameters = Map("parameter" -> "numPartitions")
+    )
+  }
+
+  test("Repartition: success, failure, retry") {
+    withTempDir { dir =>
+      val originalPartitions = 3
+      val batchId = runSimpleStreamQuery(originalPartitions, dir.getAbsolutePath)
+      val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark, dir.getAbsolutePath)
+      // Shouldn't be seen as a repartition batch
+      assert(!isRepartitionBatch(batchId, checkpointMetadata.offsetLog, dir.getAbsolutePath))
+
+      // Trying to repartition to the same number should fail
+      val ex = intercept[StateRepartitionShufflePartitionsAlreadyMatchError] {
+        spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, originalPartitions)
+      }
+      checkError(
+        ex,
+        condition = "STATE_REPARTITION_INVALID_CHECKPOINT.SHUFFLE_PARTITIONS_ALREADY_MATCH",
+        parameters = Map(
+          "checkpointLocation" -> dir.getAbsolutePath,
+          "batchId" -> batchId.toString,
+          "numPartitions" -> originalPartitions.toString
+        )
+      )
+
+      // Trying to repartition to a different number should succeed
+      val newPartitions = originalPartitions + 1
+      spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, newPartitions)
+      val repartitionBatchId = batchId + 1
+      verifyRepartitionBatch(
+        repartitionBatchId, checkpointMetadata, dir.getAbsolutePath, newPartitions)
+
+      // Now delete the repartition commit to simulate a failed repartition attempt.
+      // This will delete all the commits after the batchId.
+      checkpointMetadata.commitLog.purgeAfter(batchId)
+
+      // Trying to repartition with a different numPartitions should fail,
+      // since it will see an uncommitted repartition batch with a different numPartitions.
+      val ex2 = intercept[StateRepartitionLastBatchAbandonedRepartitionError] {
+        spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, newPartitions + 1)
+      }
+      checkError(
+        ex2,
+        condition = "STATE_REPARTITION_INVALID_CHECKPOINT.LAST_BATCH_ABANDONED_REPARTITION",
+        parameters = Map(
+          "checkpointLocation" -> dir.getAbsolutePath,
+          "lastBatchId" -> repartitionBatchId.toString,
+          "lastBatchShufflePartitions" -> newPartitions.toString,
+          "numPartitions" -> (newPartitions + 1).toString
+        )
+      )
+
+      // Retrying with the same numPartitions should work
+      spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, newPartitions)
+      verifyRepartitionBatch(
+        repartitionBatchId, checkpointMetadata, dir.getAbsolutePath, newPartitions)
+    }
+  }
+
+  test("Query last batch failed before repartitioning") {
+    withTempDir { dir =>
+      val originalPartitions = 3
+      val input = MemoryStream[Int]
+      // Run 3 batches
+      val firstBatchId = 0
+      val lastBatchId = firstBatchId + 2
+      (firstBatchId to lastBatchId).foreach { _ =>
+        runSimpleStreamQuery(originalPartitions, dir.getAbsolutePath, input)
+      }
+      val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark, dir.getAbsolutePath)
+
+      // Let's keep only the first commit to simulate multiple failed batches
+      checkpointMetadata.commitLog.purgeAfter(firstBatchId)
+
+      // Now repartitioning should fail
+      val ex = intercept[StateRepartitionLastBatchFailedError] {
+        spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, originalPartitions + 1)
+      }
+      checkError(
+        ex,
+        condition = "STATE_REPARTITION_INVALID_CHECKPOINT.LAST_BATCH_FAILED",
+        parameters = Map(
+          "checkpointLocation" -> dir.getAbsolutePath,
+          "lastBatchId" -> lastBatchId.toString
+        )
+      )
+
+      // Setting enforceExactlyOnceSink to false should allow repartitioning
+      spark.streamingCheckpointManager.repartition(
+        dir.getAbsolutePath, originalPartitions + 1, enforceExactlyOnceSink = false)
+      verifyRepartitionBatch(
+        lastBatchId + 1,
+        checkpointMetadata,
+        dir.getAbsolutePath,
+        originalPartitions + 1,
+        // Repartition should be based on the first batch, since we skipped the others
+        baseBatchId = Some(firstBatchId))
+    }
+  }
+
+  // Each repartition adds exactly one batch on top of the previous one.
+  test("Consecutive repartition") {
+    withTempDir { dir =>
+      val originalPartitions = 3
+      val batchId = runSimpleStreamQuery(originalPartitions, dir.getAbsolutePath)
+
+      val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark, dir.getAbsolutePath)
+
+      // decrease
+      spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, originalPartitions - 1)
+      verifyRepartitionBatch(
+        batchId + 1,
+        checkpointMetadata,
+        dir.getAbsolutePath,
+        originalPartitions - 1
+      )
+
+      // increase
+      spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, originalPartitions + 1)
+      verifyRepartitionBatch(
+        batchId + 2,
+        checkpointMetadata,
+        dir.getAbsolutePath,
+        originalPartitions + 1
+      )
+    }
+  }
+
+  /**
+   * Runs a streaming `groupBy().count()` query in Update mode over `input` against
+   * `checkpointLocation`, configured with the RocksDB state store provider and
+   * `numPartitions` shuffle partitions. The stream actions add the values 1, 2, 3 and then
+   * issue `ProcessAllAvailable()`.
+   *
+   * @return the batch ID reported by the last progress of the query; the test fails if no
+   *         progress was reported
+   */
+  private def runSimpleStreamQuery(
+      numPartitions: Int,
+      checkpointLocation: String,
+      input: MemoryStream[Int] = MemoryStream[Int]): Long = {
+    val conf = Map(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.SHUFFLE_PARTITIONS.key -> numPartitions.toString)
+
+    // Stays at -1 when the query reports no progress, which trips the assertion below.
+    var committedBatchId: Long = -1
+    testStream(input.toDF().groupBy().count(), outputMode = OutputMode.Update)(
+      StartStream(checkpointLocation = checkpointLocation, additionalConfs = conf),
+      AddData(input, 1, 2, 3),
+      ProcessAllAvailable(),
+      Execute { query =>
+        committedBatchId = Option(query.lastProgress).map(_.batchId).getOrElse(-1)
+      }
+    )
+
+    assert(committedBatchId >= 0, "No batch was committed in the streaming query")
+    committedBatchId
+  }
+
+  /**
+   * Asserts that `batchId` is the latest entry of both the offset log and the commit log,
+   * that it is seen as a repartition batch with `expectedShufflePartitions`, and that its
+   * offset log entry matches the batch it is based on except for the shuffle partitions.
+   *
+   * @param baseBatchId batch the repartition is based on; `batchId - 1` when not given
+   */
+  private def verifyRepartitionBatch(
+      batchId: Long,
+      checkpointMetadata: StreamingQueryCheckpointMetadata,
+      checkpointLocation: String,
+      expectedShufflePartitions: Int,
+      baseBatchId: Option[Long] = None): Unit = {
+    // Should be seen as a repartition batch
+    assert(isRepartitionBatch(batchId, checkpointMetadata.offsetLog, checkpointLocation))
+
+    // Verify the repartition batch
+    val lastBatchId = checkpointMetadata.offsetLog.getLatestBatchId().get
+    assert(lastBatchId == batchId)
+
+    val lastBatch = checkpointMetadata.offsetLog.get(lastBatchId).get
+    val lastBatchShufflePartitions = getShufflePartitions(lastBatch.metadata.get).get
+    assert(lastBatchShufflePartitions == expectedShufflePartitions)
+
+    // Verify the commit log
+    val lastCommitId = checkpointMetadata.commitLog.getLatestBatchId().get
+    assert(lastCommitId == batchId)
+
+    // Verify that the offset seq is the same between repartition batch and
+    // the batch the repartition is based on except for the shuffle partitions.
+    // When failed batches are skipped, then repartition can be based
+    // on an older batch and not batchId - 1.
+    val previousBatchId = baseBatchId.getOrElse(batchId - 1)
+    val previousBatch = checkpointMetadata.offsetLog.get(previousBatchId).get
+
+    // Verify offsets are identical
+    assert(lastBatch.offsets == previousBatch.offsets,
+      s"Offsets should be identical between batch $previousBatchId and $batchId")
+
+    // Verify metadata is the same except for shuffle partitions config
+    (lastBatch.metadata, previousBatch.metadata) match {
+      case (Some(lastMetadata), Some(previousMetadata)) =>
+        // Check watermark and timestamp are the same
+        assert(lastMetadata.batchWatermarkMs == previousMetadata.batchWatermarkMs,
+          "Batch watermark should be the same")
+        assert(lastMetadata.batchTimestampMs == previousMetadata.batchTimestampMs,
+          "Batch timestamp should be the same")
+
+        // Check all configs are the same except shuffle partitions
+        val lastConfWithoutShufflePartitions =
+          lastMetadata.conf - SQLConf.SHUFFLE_PARTITIONS.key
+        val previousConfWithoutShufflePartitions =
+          previousMetadata.conf - SQLConf.SHUFFLE_PARTITIONS.key
+        assert(lastConfWithoutShufflePartitions == previousConfWithoutShufflePartitions,
+          "All configs except shuffle partitions should be the same")
+
+        // Verify shuffle partitions are different
+        assert(
+          getShufflePartitions(lastMetadata).get != getShufflePartitions(previousMetadata).get,
+          "Shuffle partitions should be different between batches")
+      case _ =>
+        fail("Both batches should have metadata")
+    }
+  }
+}