Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/amp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,7 @@ def query_and_load_streaming(

# Optionally wrap with reorg detection
if with_reorg_detection:
stream_iterator = ReorgAwareStream(stream_iterator)
stream_iterator = ReorgAwareStream(stream_iterator, resume_watermark=resume_watermark)
self.logger.info('Reorg detection enabled for streaming query')

# Start continuous loading with checkpoint support
Expand Down
53 changes: 43 additions & 10 deletions src/amp/streaming/reorg.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Dict, Iterator, List, Optional

from .iterator import StreamingResultIterator
from .types import BlockRange, ResponseBatch
from .types import BlockRange, ResponseBatch, ResumeWatermark


class ReorgAwareStream:
Expand All @@ -16,20 +16,32 @@ class ReorgAwareStream:
This class monitors the block ranges in consecutive batches to detect chain
reorganizations (reorgs). When a reorg is detected, a ResponseBatch with
is_reorg=True is emitted containing the invalidation ranges.

Supports cross-restart reorg detection by initializing from a resume watermark
that contains the last known block hashes from persistent state.
"""

def __init__(self, stream_iterator: StreamingResultIterator):
def __init__(
    self,
    stream_iterator: StreamingResultIterator,
    resume_watermark: Optional[ResumeWatermark] = None,
):
    """
    Initialize the reorg-aware stream.

    Args:
        stream_iterator: The underlying streaming result iterator
        resume_watermark: Optional watermark from persistent state (LMDB) containing
            last known block ranges with hashes for cross-restart reorg detection.
            When omitted (None), no previous ranges are known and tracking
            starts empty.
    """
    self.stream_iterator = stream_iterator
    # Track the latest range for each network
    self.prev_ranges_by_network: Dict[str, BlockRange] = {}
    self.logger = logging.getLogger(__name__)

    # Explicit None check: "no watermark supplied" is the only case that
    # skips seeding; a watermark with an empty range list simply seeds nothing.
    if resume_watermark is not None:
        for block_range in resume_watermark.ranges:
            # Seed the tracker so the first batch after a restart is compared
            # against the persisted range instead of being accepted blindly.
            # If a network appears more than once, the last entry wins.
            self.prev_ranges_by_network[block_range.network] = block_range
            self.logger.debug(
                f'Initialized reorg detection for {block_range.network} '
                f'from block {block_range.end} hash {block_range.hash}'
            )

def __iter__(self) -> Iterator[ResponseBatch]:
    """
    Return this stream itself as the iterator.

    Lets the stream be used directly in a ``for`` loop; batches are presumably
    produced by a ``__next__`` defined elsewhere in the class (not visible here).
    """
    return self
Expand Down Expand Up @@ -89,9 +101,9 @@ def _detect_reorg(self, current_ranges: List[BlockRange]) -> List[BlockRange]:
"""
Detect reorganizations by comparing current ranges with previous ranges.

A reorg is detected when:
- A range starts at or before the end of the previous range for the same network
- The range is different from the previous range
A reorg is detected when either:
1. Block number overlap: current range starts at or before previous range end
2. Hash mismatch: server's prev_hash doesn't match our stored hash (cross-restart detection)

Args:
current_ranges: Block ranges from the current batch
Expand All @@ -102,18 +114,39 @@ def _detect_reorg(self, current_ranges: List[BlockRange]) -> List[BlockRange]:
invalidation_ranges = []

for current_range in current_ranges:
# Get the previous range for this network
prev_range = self.prev_ranges_by_network.get(current_range.network)

if prev_range:
# Check if this indicates a reorg
is_reorg = False

# Detection 1: Block number overlap (original logic)
if current_range != prev_range and current_range.start <= prev_range.end:
# Reorg detected - create invalidation range
# Invalidate from the start of the current range to the max end
is_reorg = True
self.logger.info(
f'Reorg detected via block overlap: {current_range.network} '
f'current start {current_range.start} <= prev end {prev_range.end}'
)

# Detection 2: Hash mismatch (cross-restart detection)
# Server sends prev_hash = hash of block before current range
# If it doesn't match our stored hash, chain has changed
elif (
current_range.prev_hash is not None
and prev_range.hash is not None
and current_range.prev_hash != prev_range.hash
):
is_reorg = True
self.logger.info(
f'Reorg detected via hash mismatch: {current_range.network} '
f'server prev_hash {current_range.prev_hash} != stored hash {prev_range.hash}'
)

if is_reorg:
invalidation = BlockRange(
network=current_range.network,
start=current_range.start,
end=max(current_range.end, prev_range.end),
hash=prev_range.hash,
)
invalidation_ranges.append(invalidation)

Expand Down
102 changes: 102 additions & 0 deletions tests/unit/test_streaming_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,3 +619,105 @@ class MockIterator:
stream = ReorgAwareStream(MockIterator())

assert stream._is_duplicate_batch([]) == False

def test_init_from_resume_watermark(self):
    """A resume watermark seeds the per-network previous-range map, hashes included."""

    class StubIterator:
        pass

    eth_range = BlockRange(network='ethereum', start=100, end=200, hash='0xabc123')
    polygon_range = BlockRange(network='polygon', start=50, end=150, hash='0xdef456')
    watermark = ResumeWatermark(ranges=[eth_range, polygon_range])

    stream = ReorgAwareStream(StubIterator(), resume_watermark=watermark)
    tracked = stream.prev_ranges_by_network

    # Every network named in the watermark must be tracked with its stored hash.
    expected_hashes = {'ethereum': '0xabc123', 'polygon': '0xdef456'}
    for network, expected_hash in expected_hashes.items():
        assert network in tracked
        assert tracked[network].hash == expected_hash

def test_detect_reorg_hash_mismatch(self):
    """A prev_hash that differs from the stored hash yields one invalidation range."""

    class StubIterator:
        pass

    # Stored state ends at block 200; the incoming range continues at 201 but
    # reports a different hash for the block preceding it.
    stored = BlockRange(network='ethereum', start=100, end=200, hash='0xoriginal')
    incoming = BlockRange(network='ethereum', start=201, end=300, prev_hash='0xdifferent')

    stream = ReorgAwareStream(StubIterator())
    stream.prev_ranges_by_network = {'ethereum': stored}

    invalidations = stream._detect_reorg([incoming])

    assert len(invalidations) == 1
    reorg_range = invalidations[0]
    assert reorg_range.network == 'ethereum'
    assert reorg_range.hash == '0xoriginal'

def test_detect_reorg_hash_match_no_reorg(self):
    """Contiguous ranges whose hashes agree must not produce any invalidation."""

    class StubIterator:
        pass

    stored = BlockRange(network='ethereum', start=100, end=200, hash='0xsame')
    incoming = BlockRange(network='ethereum', start=201, end=300, prev_hash='0xsame')

    stream = ReorgAwareStream(StubIterator())
    stream.prev_ranges_by_network = {'ethereum': stored}

    invalidations = stream._detect_reorg([incoming])

    assert len(invalidations) == 0

def test_detect_reorg_hash_mismatch_with_none_prev_hash(self):
    """No reorg is reported when the incoming range carries no prev_hash (genesis case)."""

    class StubIterator:
        pass

    # The stored range is the single block 0; the next range starts at block 1
    # and has prev_hash=None, so there is nothing to compare against.
    stored = BlockRange(network='ethereum', start=0, end=0, hash='0xgenesis')
    incoming = BlockRange(network='ethereum', start=1, end=100, prev_hash=None)

    stream = ReorgAwareStream(StubIterator())
    stream.prev_ranges_by_network = {'ethereum': stored}

    invalidations = stream._detect_reorg([incoming])

    assert len(invalidations) == 0

def test_detect_reorg_hash_mismatch_with_none_stored_hash(self):
    """No reorg is reported when the stored range has no hash to compare with."""

    class StubIterator:
        pass

    # The stored range lacks a hash, so the incoming prev_hash cannot be checked.
    stored = BlockRange(network='ethereum', start=100, end=200, hash=None)
    incoming = BlockRange(network='ethereum', start=201, end=300, prev_hash='0xsome_hash')

    stream = ReorgAwareStream(StubIterator())
    stream.prev_ranges_by_network = {'ethereum': stored}

    invalidations = stream._detect_reorg([incoming])

    assert len(invalidations) == 0
Loading