diff --git a/quixstreams/state/base/transaction.py b/quixstreams/state/base/transaction.py index 432b3922a..614e7ed88 100644 --- a/quixstreams/state/base/transaction.py +++ b/quixstreams/state/base/transaction.py @@ -26,6 +26,8 @@ from quixstreams.state.metadata import ( CHANGELOG_CF_MESSAGE_HEADER, CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER, + CHANGELOG_TRANSACTION_END_KEY, + CHANGELOG_TRANSACTION_START_KEY, DEFAULT_PREFIX, SEPARATOR, Marker, @@ -511,6 +513,7 @@ def _prepare(self, processed_offsets: Optional[dict[str, int]]): source_tp_offset_header = json_dumps(processed_offsets) column_families = self._update_cache.get_column_families() + changelog_tx_started = False for cf_name in column_families: headers: Headers = { CHANGELOG_CF_MESSAGE_HEADER: cf_name, @@ -520,6 +523,11 @@ def _prepare(self, processed_offsets: Optional[dict[str, int]]): updates = self._update_cache.get_updates(cf_name=cf_name) for prefix_update_cache in updates.values(): for key, value in prefix_update_cache.items(): + if not changelog_tx_started: + self._changelog_producer.produce( + key=CHANGELOG_TRANSACTION_START_KEY + ) + changelog_tx_started = True self._changelog_producer.produce( key=key, value=value, @@ -528,12 +536,20 @@ def _prepare(self, processed_offsets: Optional[dict[str, int]]): deletes = self._update_cache.get_deletes(cf_name=cf_name) for key in deletes: + if not changelog_tx_started: + self._changelog_producer.produce( + key=CHANGELOG_TRANSACTION_START_KEY + ) + changelog_tx_started = True self._changelog_producer.produce( key=key, value=None, headers=headers, ) + if changelog_tx_started: + self._changelog_producer.produce(key=CHANGELOG_TRANSACTION_END_KEY) + @validate_transaction_status( PartitionTransactionStatus.STARTED, PartitionTransactionStatus.PREPARED ) diff --git a/quixstreams/state/metadata.py b/quixstreams/state/metadata.py index 09dd70e72..ce57aaa23 100644 --- a/quixstreams/state/metadata.py +++ b/quixstreams/state/metadata.py @@ -5,6 +5,8 @@ CHANGELOG_CF_MESSAGE_HEADER = "__column_family__" CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER = "__processed_tp_offsets__" +CHANGELOG_TRANSACTION_START_KEY = b"__transaction_start__" +CHANGELOG_TRANSACTION_END_KEY = b"__transaction_end__" METADATA_CF_NAME = "__metadata__" DEFAULT_PREFIX = b"" diff --git a/quixstreams/state/recovery.py b/quixstreams/state/recovery.py index b79c30188..c301131ea 100644 --- a/quixstreams/state/recovery.py +++ b/quixstreams/state/recovery.py @@ -23,6 +23,8 @@ from .metadata import ( CHANGELOG_CF_MESSAGE_HEADER, CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER, + CHANGELOG_TRANSACTION_END_KEY, + CHANGELOG_TRANSACTION_START_KEY, ) logger = logging.getLogger(__name__) @@ -552,12 +554,27 @@ def _recovery_loop(self) -> None: A RecoveryPartition is unassigned immediately once fully updated. """ + changelog_tx_started = False + changelog_tx_buffer: list[SuccessfulConfluentKafkaMessageProto] = [] while self.recovering: self._log_recovery_progress() if (msg := self._consumer.poll(1)) is None: self._update_recovery_status() + continue + + msg = raise_for_msg_error(msg) + + if msg.key() == CHANGELOG_TRANSACTION_START_KEY: + changelog_tx_started = True + elif msg.key() == CHANGELOG_TRANSACTION_END_KEY: + for msg in changelog_tx_buffer: + rp = self._recovery_partitions[msg.partition()][msg.topic()] + rp.recover_from_changelog_message(changelog_message=msg) + changelog_tx_started = False + changelog_tx_buffer = [] + elif changelog_tx_started: + changelog_tx_buffer.append(msg) else: - msg = raise_for_msg_error(msg) rp = self._recovery_partitions[msg.partition()][msg.topic()] rp.recover_from_changelog_message(changelog_message=msg)