From cdb1f28e407c94e3fd39d76fc90a0ac2c66b7610 Mon Sep 17 00:00:00 2001 From: Krishnanand V P Date: Tue, 13 Jan 2026 16:43:00 +0400 Subject: [PATCH] Add cross-restart reorg detection using persistent state - Initialize ReorgAwareStream from resume_watermark to restore block hashes - Add hash mismatch detection: compare stored hash with server's prev_hash - Pass resume_watermark from client.py to ReorgAwareStream - Add 6 unit tests for cross-restart reorg detection scenarios This enables detecting chain reorgs that occur while the process is down by comparing the stored block hash (from state store) with the server's prev_hash. --- src/amp/client.py | 2 +- src/amp/streaming/reorg.py | 53 ++++++++++++--- tests/unit/test_streaming_types.py | 102 +++++++++++++++++++++++++++++ 3 files changed, 146 insertions(+), 11 deletions(-) diff --git a/src/amp/client.py b/src/amp/client.py index 2eee462..d58aac5 100644 --- a/src/amp/client.py +++ b/src/amp/client.py @@ -817,7 +817,7 @@ def query_and_load_streaming( # Optionally wrap with reorg detection if with_reorg_detection: - stream_iterator = ReorgAwareStream(stream_iterator) + stream_iterator = ReorgAwareStream(stream_iterator, resume_watermark=resume_watermark) self.logger.info('Reorg detection enabled for streaming query') # Start continuous loading with checkpoint support diff --git a/src/amp/streaming/reorg.py b/src/amp/streaming/reorg.py index 9083db7..8672eb0 100644 --- a/src/amp/streaming/reorg.py +++ b/src/amp/streaming/reorg.py @@ -6,7 +6,7 @@ from typing import Dict, Iterator, List from .iterator import StreamingResultIterator -from .types import BlockRange, ResponseBatch +from .types import BlockRange, ResponseBatch, ResumeWatermark class ReorgAwareStream: @@ -16,20 +16,32 @@ class ReorgAwareStream: This class monitors the block ranges in consecutive batches to detect chain reorganizations (reorgs). When a reorg is detected, a ResponseBatch with is_reorg=True is emitted containing the invalidation ranges. + + Supports cross-restart reorg detection by initializing from a resume watermark + that contains the last known block hashes from persistent state. """ - def __init__(self, stream_iterator: StreamingResultIterator): + def __init__(self, stream_iterator: StreamingResultIterator, resume_watermark: ResumeWatermark = None): """ Initialize the reorg-aware stream. Args: stream_iterator: The underlying streaming result iterator + resume_watermark: Optional watermark from persistent state (LMDB) containing + last known block ranges with hashes for cross-restart reorg detection """ self.stream_iterator = stream_iterator - # Track the latest range for each network self.prev_ranges_by_network: Dict[str, BlockRange] = {} self.logger = logging.getLogger(__name__) + if resume_watermark: + for block_range in resume_watermark.ranges: + self.prev_ranges_by_network[block_range.network] = block_range + self.logger.debug( + f'Initialized reorg detection for {block_range.network} ' + f'from block {block_range.end} hash {block_range.hash}' + ) + def __iter__(self) -> Iterator[ResponseBatch]: """Return iterator instance""" return self @@ -89,9 +101,9 @@ def _detect_reorg(self, current_ranges: List[BlockRange]) -> List[BlockRange]: """ Detect reorganizations by comparing current ranges with previous ranges. - A reorg is detected when: - - A range starts at or before the end of the previous range for the same network - - The range is different from the previous range + A reorg is detected when either: + 1. Block number overlap: current range starts at or before previous range end + 2. Hash mismatch: server's prev_hash doesn't match our stored hash (cross-restart detection) Args: current_ranges: Block ranges from the current batch @@ -102,18 +114,39 @@ def _detect_reorg(self, current_ranges: List[BlockRange]) -> List[BlockRange]: invalidation_ranges = [] for current_range in current_ranges: - # Get the previous range for this network prev_range = self.prev_ranges_by_network.get(current_range.network) if prev_range: - # Check if this indicates a reorg + is_reorg = False + + # Detection 1: Block number overlap (original logic) if current_range != prev_range and current_range.start <= prev_range.end: - # Reorg detected - create invalidation range - # Invalidate from the start of the current range to the max end + is_reorg = True + self.logger.info( + f'Reorg detected via block overlap: {current_range.network} ' + f'current start {current_range.start} <= prev end {prev_range.end}' + ) + + # Detection 2: Hash mismatch (cross-restart detection) + # Server sends prev_hash = hash of block before current range + # If it doesn't match our stored hash, chain has changed + elif ( + current_range.prev_hash is not None + and prev_range.hash is not None + and current_range.prev_hash != prev_range.hash + ): + is_reorg = True + self.logger.info( + f'Reorg detected via hash mismatch: {current_range.network} ' + f'server prev_hash {current_range.prev_hash} != stored hash {prev_range.hash}' + ) + + if is_reorg: invalidation = BlockRange( network=current_range.network, start=current_range.start, end=max(current_range.end, prev_range.end), + hash=prev_range.hash, ) invalidation_ranges.append(invalidation) diff --git a/tests/unit/test_streaming_types.py b/tests/unit/test_streaming_types.py index 47eede2..e4cc1fe 100644 --- a/tests/unit/test_streaming_types.py +++ b/tests/unit/test_streaming_types.py @@ -619,3 +619,105 @@ class MockIterator: stream = ReorgAwareStream(MockIterator()) assert stream._is_duplicate_batch([]) == False + + def test_init_from_resume_watermark(self): + """Test initialization from resume watermark for cross-restart reorg detection""" + + class MockIterator: + pass + + watermark = ResumeWatermark( + ranges=[ + BlockRange(network='ethereum', start=100, end=200, hash='0xabc123'), + BlockRange(network='polygon', start=50, end=150, hash='0xdef456'), + ] + ) + + stream = ReorgAwareStream(MockIterator(), resume_watermark=watermark) + + assert 'ethereum' in stream.prev_ranges_by_network + assert 'polygon' in stream.prev_ranges_by_network + assert stream.prev_ranges_by_network['ethereum'].hash == '0xabc123' + assert stream.prev_ranges_by_network['polygon'].hash == '0xdef456' + + def test_detect_reorg_hash_mismatch(self): + """Test reorg detection via hash mismatch (cross-restart detection)""" + + class MockIterator: + pass + + stream = ReorgAwareStream(MockIterator()) + + stream.prev_ranges_by_network = { + 'ethereum': BlockRange(network='ethereum', start=100, end=200, hash='0xoriginal'), + } + + current_ranges = [ + BlockRange(network='ethereum', start=201, end=300, prev_hash='0xdifferent'), + ] + + invalidations = stream._detect_reorg(current_ranges) + + assert len(invalidations) == 1 + assert invalidations[0].network == 'ethereum' + assert invalidations[0].hash == '0xoriginal' + + def test_detect_reorg_hash_match_no_reorg(self): + """Test no reorg when hashes match across restart""" + + class MockIterator: + pass + + stream = ReorgAwareStream(MockIterator()) + + stream.prev_ranges_by_network = { + 'ethereum': BlockRange(network='ethereum', start=100, end=200, hash='0xsame'), + } + + current_ranges = [ + BlockRange(network='ethereum', start=201, end=300, prev_hash='0xsame'), + ] + + invalidations = stream._detect_reorg(current_ranges) + + assert len(invalidations) == 0 + + def test_detect_reorg_hash_mismatch_with_none_prev_hash(self): + """Test no reorg detection when server prev_hash is None (genesis block)""" + + class MockIterator: + pass + + stream = ReorgAwareStream(MockIterator()) + + stream.prev_ranges_by_network = { + 'ethereum': BlockRange(network='ethereum', start=0, end=0, hash='0xgenesis'), + } + + current_ranges = [ + BlockRange(network='ethereum', start=1, end=100, prev_hash=None), + ] + + invalidations = stream._detect_reorg(current_ranges) + + assert len(invalidations) == 0 + + def test_detect_reorg_hash_mismatch_with_none_stored_hash(self): + """Test no reorg detection when stored hash is None""" + + class MockIterator: + pass + + stream = ReorgAwareStream(MockIterator()) + + stream.prev_ranges_by_network = { + 'ethereum': BlockRange(network='ethereum', start=100, end=200, hash=None), + } + + current_ranges = [ + BlockRange(network='ethereum', start=201, end=300, prev_hash='0xsome_hash'), + ] + + invalidations = stream._detect_reorg(current_ranges) + + assert len(invalidations) == 0