From f0e4c84fe8728da4be3169fdf9917354512b1e2e Mon Sep 17 00:00:00 2001 From: Preston Peranich Date: Thu, 15 Jan 2026 10:18:25 -0500 Subject: [PATCH 1/3] feat: implement leaky subscribers. --- examples/ezmsg_leaky_subscriber.py | 171 +++++++++++++ src/ezmsg/core/backendprocess.py | 7 +- src/ezmsg/core/messagechannel.py | 90 ++++++- src/ezmsg/core/stream.py | 66 ++++- src/ezmsg/core/subclient.py | 43 +++- src/ezmsg/core/unit.py | 4 + tests/test_leaky_subscriber.py | 388 +++++++++++++++++++++++++++++ 7 files changed, 746 insertions(+), 23 deletions(-) create mode 100644 examples/ezmsg_leaky_subscriber.py create mode 100644 tests/test_leaky_subscriber.py diff --git a/examples/ezmsg_leaky_subscriber.py b/examples/ezmsg_leaky_subscriber.py new file mode 100644 index 0000000..5aaeb56 --- /dev/null +++ b/examples/ezmsg_leaky_subscriber.py @@ -0,0 +1,171 @@ +# Leaky Subscriber Example +# +# This example demonstrates the "leaky subscriber" feature, which allows +# slow consumers to drop old messages rather than blocking fast producers. +# +# Scenario: +# - A fast publisher produces messages at ~10 Hz (every 100ms) +# - A slow subscriber processes messages at ~1 Hz (1000ms per message) +# - Without leaky mode: the publisher would be blocked by backpressure +# - With leaky mode: old messages are dropped, subscriber always gets recent data +# +# This is useful for real-time applications where you want the latest data +# rather than processing a growing backlog of stale messages. + +import asyncio +from collections.abc import AsyncGenerator +from dataclasses import dataclass, field + +import ezmsg.core as ez + + +@dataclass +class TimestampedMessage: + """A message with sequence number and timestamp for tracking latency.""" + + seq: int + created_at: float = field(default_factory=lambda: asyncio.get_event_loop().time()) + + +class FastPublisherSettings(ez.Settings): + num_messages: int = 20 + publish_interval_sec: float = 0.1 # 10 Hz + + +class FastPublisher(ez.Unit): + """Publishes messages at ~10 Hz.""" + + SETTINGS = FastPublisherSettings + + OUTPUT = ez.OutputStream(TimestampedMessage, num_buffers=32) + + @ez.publisher(OUTPUT) + async def publish(self) -> AsyncGenerator: + # Small delay to ensure subscriber is ready + await asyncio.sleep(0.5) + + for seq in range(self.SETTINGS.num_messages): + msg = TimestampedMessage(seq=seq) + print(f"[Publisher] Sending seq={seq}", flush=True) + yield (self.OUTPUT, msg) + await asyncio.sleep(self.SETTINGS.publish_interval_sec) + + print("[Publisher] Done sending all messages", flush=True) + raise ez.Complete + + +class SlowSubscriberSettings(ez.Settings): + process_time_sec: float = 1.0 # Simulates slow processing at ~1 Hz + expected_messages: int = 20 + + +class SlowSubscriberState(ez.State): + received_count: int = 0 + received_seqs: list = None + total_latency: float = 0.0 + + +class SlowSubscriber(ez.Unit): + """ + A slow subscriber that takes 1 second to process each message. + + Uses a leaky InputStream to drop old messages when it can't keep up, + ensuring it always processes relatively recent data. + """ + + SETTINGS = SlowSubscriberSettings + STATE = SlowSubscriberState + + # Leaky input stream with max queue of 3 messages + # When the queue fills up, oldest messages are dropped + INPUT = ez.InputStream(TimestampedMessage, leaky=True, max_queue=3) + + async def initialize(self) -> None: + self.STATE.received_seqs = [] + + @ez.subscriber(INPUT) + async def on_message(self, msg: TimestampedMessage) -> None: + now = asyncio.get_event_loop().time() + latency_ms = (now - msg.created_at) * 1000 + + self.STATE.received_count += 1 + self.STATE.total_latency += latency_ms + self.STATE.received_seqs.append(msg.seq) + + print( + f"[Subscriber] Processing seq={msg.seq:3d}, " + f"latency={latency_ms:6.0f}ms", + flush=True, + ) + + # Simulate slow processing + await asyncio.sleep(self.SETTINGS.process_time_sec) + + # Terminate after receiving the last message + if msg.seq == self.SETTINGS.expected_messages - 1: + raise ez.NormalTermination + + async def shutdown(self) -> None: + dropped = self.SETTINGS.expected_messages - self.STATE.received_count + avg_latency = ( + self.STATE.total_latency / self.STATE.received_count + if self.STATE.received_count > 0 + else 0 + ) + + print("\n" + "=" * 60, flush=True) + print("LEAKY SUBSCRIBER SUMMARY", flush=True) + print("=" * 60, flush=True) + print(f" Messages published: {self.SETTINGS.expected_messages}", flush=True) + print(f" Messages received: {self.STATE.received_count}", flush=True) + print(f" Messages dropped: {dropped}", flush=True) + print(f" Sequences received: {self.STATE.received_seqs}", flush=True) + print(f" Average latency: {avg_latency:.0f}ms", flush=True) + print("=" * 60, flush=True) + print( + "\nNote: With leaky=True, the subscriber drops old messages to stay\n" + " current. Without it, backpressure would slow the publisher.", + flush=True, + ) + + +class LeakyDemo(ez.Collection): + """Demo system with a fast publisher and slow leaky subscriber.""" + + SETTINGS = FastPublisherSettings + + PUB = FastPublisher() + SUB = SlowSubscriber() + + def configure(self) -> None: + num_msgs = self.SETTINGS.num_messages + self.PUB.apply_settings( + FastPublisherSettings( + num_messages=num_msgs, + publish_interval_sec=self.SETTINGS.publish_interval_sec, + ) + ) + self.SUB.apply_settings( + SlowSubscriberSettings(process_time_sec=1.0, expected_messages=num_msgs) + ) + + def network(self) -> ez.NetworkDefinition: + return ((self.PUB.OUTPUT, self.SUB.INPUT),) + + +if __name__ == "__main__": + print("Leaky Subscriber Demo", flush=True) + print("=" * 60, flush=True) + print("Publisher: 20 messages at 10 Hz (100ms intervals)", flush=True) + print("Subscriber: Processes at 1 Hz (1000ms per message)", flush=True) + print("Queue: max_queue=3, leaky=True", flush=True) + print("=" * 60, flush=True) + print("\nExpected behavior:", flush=True) + print("- Publisher sends 20 messages over ~2 seconds", flush=True) + print("- Subscriber can only process ~1 message per second", flush=True) + print("- Many messages will be dropped to keep subscriber current", flush=True) + print("=" * 60 + "\n", flush=True) + + settings = FastPublisherSettings(num_messages=20, publish_interval_sec=0.1) + system = LeakyDemo(settings) + ez.run(DEMO=system) diff --git a/src/ezmsg/core/backendprocess.py b/src/ezmsg/core/backendprocess.py index ba93ead..cf79394 100644 --- a/src/ezmsg/core/backendprocess.py +++ b/src/ezmsg/core/backendprocess.py @@ -201,7 +201,12 @@ async def setup_state(): if isinstance(stream, InputStream): logger.debug(f"Creating Subscriber from {stream}") sub = asyncio.run_coroutine_threadsafe( - context.subscriber(stream.address), loop + context.subscriber( + stream.address, + leaky=stream.leaky, + max_queue=stream.max_queue, + ), + loop, ).result() task_name = f"SUBSCRIBER|{stream.address}" coro_callables[task_name] = partial( diff --git a/src/ezmsg/core/messagechannel.py b/src/ezmsg/core/messagechannel.py index e016f10..8d10948 100644 --- a/src/ezmsg/core/messagechannel.py +++ b/src/ezmsg/core/messagechannel.py @@ -9,7 +9,7 @@ from .shm import SHMContext from .messagemarshal import MessageMarshal from .backpressure import Backpressure -from .messagecache import MessageCache +from .messagecache import MessageCache, CacheMiss from .graphserver import GraphService from .netprotocol import ( Command, @@ -25,7 +25,48 @@ logger = logging.getLogger("ezmsg") -NotificationQueue = asyncio.Queue[typing.Tuple[UUID, int]] +class LeakyQueue(asyncio.Queue[typing.Tuple[UUID, int]]): + """ + An asyncio.Queue that drops oldest items when full. + + When putting a new item into a full queue, the oldest item is + dropped to make room. + + :param maxsize: Maximum queue size (must be positive) + :param on_drop: Optional callback called with dropped item when dropping + """ + + def __init__( + self, + maxsize: int, + on_drop: typing.Callable[[typing.Any], None] | None = None, + ): + super().__init__(maxsize=maxsize) + self._on_drop = on_drop + + def _drop_oldest(self) -> None: + """Drop the oldest item from the queue, calling on_drop if set.""" + try: + dropped = self.get_nowait() + if self._on_drop is not None: + self._on_drop(dropped) + except asyncio.QueueEmpty: + pass + + async def put(self, item: typing.Tuple[UUID, int]) -> None: + """Put an item into the queue, dropping oldest if full.""" + if self.full(): + self._drop_oldest() + await super().put(item) + + def put_nowait(self, item: typing.Tuple[UUID, int]) -> None: + """Put an item without blocking, dropping oldest if full.""" + if self.full(): + self._drop_oldest() + super().put_nowait(item) + + +NotificationQueue = asyncio.Queue[typing.Tuple[UUID, int]] | LeakyQueue class Channel: @@ -310,16 +351,41 @@ def get( try: yield self.cache[msg_id] finally: - buf_idx = msg_id % self.num_buffers - self.backpressure.free(client_id, buf_idx) - if self.backpressure.buffers[buf_idx].is_empty: - self.cache.release(msg_id) - - # If pub is in same process as this channel, avoid TCP - if self._local_backpressure is not None: - self._local_backpressure.free(self.id, buf_idx) - else: - self._acknowledge(msg_id) + self._release_backpressure(msg_id, client_id) + + def release_without_get(self, msg_id: int, client_id: UUID) -> None: + """ + Release backpressure for a message without retrieving it. + + Used by leaky subscribers when dropping notifications to ensure + backpressure is properly released for messages that will never be read. + + :param msg_id: Message ID to release + :type msg_id: int + :param client_id: UUID of client releasing this message + :type client_id: UUID + """ + self._release_backpressure(msg_id, client_id) + + def _release_backpressure(self, msg_id: int, client_id: UUID) -> None: + """ + Internal method to release backpressure for a message. + + :param msg_id: Message ID to release + :type msg_id: int + :param client_id: UUID of client releasing this message + :type client_id: UUID + """ + buf_idx = msg_id % self.num_buffers + self.backpressure.free(client_id, buf_idx) + if self.backpressure.buffers[buf_idx].is_empty: + self.cache.release(msg_id) + + # If pub is in same process as this channel, avoid TCP + if self._local_backpressure is not None: + self._local_backpressure.free(self.id, buf_idx) + else: + self._acknowledge(msg_id) def _acknowledge(self, msg_id: int) -> None: try: diff --git a/src/ezmsg/core/stream.py b/src/ezmsg/core/stream.py index 2af20f3..ae7d9e1 100644 --- a/src/ezmsg/core/stream.py +++ b/src/ezmsg/core/stream.py @@ -33,12 +33,76 @@ class InputStream(Stream): InputStream represents a channel that receives messages from other components. Units can subscribe to InputStreams to process incoming messages. + Leaky Subscribers + ----------------- + + By default, ezmsg uses backpressure to prevent fast publishers from overwhelming + slow subscribers. When a subscriber can't keep up, the publisher blocks until + the subscriber catches up. This guarantees no message loss but can cause latency + buildup in real-time applications. + + Setting ``leaky=True`` creates a "leaky" subscriber that drops old messages + instead of applying backpressure. This is useful when you need the most recent + data rather than processing a growing backlog of stale messages. + + **Architecture**: The leaky behavior is implemented at the subscriber's + notification queue, *after* message serialization and transport. This means: + + - Publishers still serialize and transmit every message (to shared memory or TCP) + - The Channel still receives and caches every message + - Dropping occurs when the subscriber's notification queue is full + - Backpressure is properly released for dropped messages (ACKs sent to publisher) + + This design ensures that: + + 1. One leaky subscriber doesn't affect other subscribers to the same topic + 2. The publisher's buffer management remains consistent + 3. Backpressure accounting stays correct (no resource leaks) + + **Trade-offs**: Leaky subscribers don't reduce serialization or network overhead; + they prevent slow consumers from blocking fast producers. If you need to reduce + data transfer, consider filtering or downsampling at the publisher level. + + Example usage:: + + # Leaky subscriber that keeps at most 3 pending messages + INPUT = ez.InputStream(MyMessage, leaky=True, max_queue=3) + + @ez.subscriber(INPUT) + async def process(self, msg: MyMessage) -> None: + # Will only see recent messages; older ones dropped if queue fills + await slow_processing(msg) + :param msg_type: The type of messages this input stream will receive :type msg_type: Any + :param leaky: If True, drop oldest messages when queue is full (default: False) + :type leaky: bool + :param max_queue: Maximum queue depth for leaky mode (required if leaky=True) + :type max_queue: int | None """ + leaky: bool + max_queue: int | None + + def __init__( + self, + msg_type: Any, + leaky: bool = False, + max_queue: int | None = None, + ) -> None: + super().__init__(msg_type) + if leaky and max_queue is None: + raise ValueError("max_queue must be set when leaky=True") + if max_queue is not None and max_queue <= 0: + raise ValueError("max_queue must be positive") + self.leaky = leaky + self.max_queue = max_queue + def __repr__(self) -> str: - return f"Input{super().__repr__()}()" + base = f"Input{super().__repr__()}" + if self.leaky: + return f"{base}(leaky=True, max_queue={self.max_queue})" + return f"{base}()" class OutputStream(Stream): diff --git a/src/ezmsg/core/subclient.py b/src/ezmsg/core/subclient.py index fee285c..d88233a 100644 --- a/src/ezmsg/core/subclient.py +++ b/src/ezmsg/core/subclient.py @@ -8,7 +8,7 @@ from .graphserver import GraphService from .channelmanager import CHANNELS -from .messagechannel import NotificationQueue, Channel +from .messagechannel import NotificationQueue, LeakyQueue, Channel from .netprotocol import ( AddressType, @@ -91,11 +91,13 @@ async def create( return sub def __init__( - self, - id: UUID, - topic: str, - graph_address: AddressType | None, - _guard = None, + self, + id: UUID, + topic: str, + graph_address: AddressType | None, + _guard=None, + leaky: bool = False, + max_queue: int | None = None, **kwargs ) -> None: """ @@ -107,8 +109,12 @@ def __init__( :type id: UUID :param topic: The topic this subscriber listens to. :type topic: str - :param graph_service: Service for graph operations. - :type graph_service: GraphService + :param graph_address: Address of the graph server. + :type graph_address: AddressType | None + :param leaky: If True, drop oldest messages when queue is full. + :type leaky: bool + :param max_queue: Maximum queue size (required if leaky=True). + :type max_queue: int | None :param kwargs: Additional keyword arguments (unused). """ if _guard is not self._SENTINEL: @@ -121,10 +127,29 @@ def __init__( self._graph_address = graph_address self._cur_pubs = set() - self._incoming = asyncio.Queue() self._channels = dict() + if leaky: + self._incoming = LeakyQueue(max_queue, self._handle_dropped_notification) + else: + self._incoming = asyncio.Queue() self._initialized = asyncio.Event() + def _handle_dropped_notification( + self, notification: typing.Tuple[UUID, int] + ) -> None: + """ + Handle a dropped notification by releasing backpressure. + + Called by LeakyQueue when a notification is dropped to ensure + backpressure is properly released for messages that will never be read. + + :param notification: Tuple of (publisher_id, message_id) that was dropped. + :type notification: tuple[UUID, int] + """ + pub_id, msg_id = notification + if pub_id in self._channels: + self._channels[pub_id].release_without_get(msg_id, self.id) + def close(self) -> None: """ Close the subscriber and cancel all associated tasks. diff --git a/src/ezmsg/core/unit.py b/src/ezmsg/core/unit.py index d957d06..f4a77cf 100644 --- a/src/ezmsg/core/unit.py +++ b/src/ezmsg/core/unit.py @@ -20,6 +20,8 @@ TIMEIT_ATTR = "__ez_timeit__" ZERO_COPY_ATTR = "__ez_zerocopy__" PROCESS_ATTR = "__ez_process__" +LEAKY_ATTR = "__ez_leaky__" +MAX_QUEUE_ATTR = "__ez_max_queue__" class UnitMeta(ComponentMeta): @@ -195,6 +197,8 @@ def sub_factory(func): raise Exception(f"{func} cannot subscribe to more than one stream") setattr(func, SUBSCRIBES_ATTR, stream) setattr(func, ZERO_COPY_ATTR, zero_copy) + setattr(func, LEAKY_ATTR, stream.leaky) + setattr(func, MAX_QUEUE_ATTR, stream.max_queue) return task(func) return sub_factory diff --git a/tests/test_leaky_subscriber.py b/tests/test_leaky_subscriber.py new file mode 100644 index 0000000..71c17ad --- /dev/null +++ b/tests/test_leaky_subscriber.py @@ -0,0 +1,388 @@ +"""Tests for leaky subscriber functionality.""" + +import asyncio +from uuid import uuid4 + +import pytest +from ezmsg.core.stream import InputStream + + +def test_input_stream_default_not_leaky(): + """InputStream defaults to non-leaky with no queue limit.""" + stream = InputStream(float) + assert stream.leaky is False + assert stream.max_queue is None + + +def test_input_stream_leaky_configuration(): + """InputStream accepts leaky and max_queue parameters.""" + stream = InputStream(float, leaky=True, max_queue=5) + assert stream.leaky is True + assert stream.max_queue == 5 + + +def test_input_stream_leaky_requires_max_queue(): + """Leaky mode requires max_queue to be set.""" + with pytest.raises(ValueError, match="max_queue must be set"): + InputStream(float, leaky=True) + + +def test_input_stream_max_queue_must_be_positive(): + """max_queue must be a positive integer.""" + with pytest.raises(ValueError, match="must be positive"): + InputStream(float, leaky=True, max_queue=0) + with pytest.raises(ValueError, match="must be positive"): + InputStream(float, leaky=True, max_queue=-1) + + +@pytest.mark.asyncio +async def test_subscriber_leaky_queue_drops_oldest(): + """Leaky subscriber drops oldest messages when queue is full.""" + # Test the leaky queue helper class directly + from ezmsg.core.messagechannel import LeakyQueue + + queue = LeakyQueue(2) + + # Fill the queue + await queue.put(("pub1", 1)) + await queue.put(("pub1", 2)) + assert queue.qsize() == 2 + + # Adding third item should drop the oldest + await queue.put(("pub1", 3)) + assert queue.qsize() == 2 + + # Should get items 2 and 3 (1 was dropped) + item1 = await queue.get() + item2 = await queue.get() + assert item1 == ("pub1", 2) + assert item2 == ("pub1", 3) + + +def test_subscriber_init_creates_leaky_queue(): + """Subscriber creates LeakyQueue when leaky=True.""" + from ezmsg.core.subclient import Subscriber + from ezmsg.core.messagechannel import LeakyQueue + + # Access internal constructor with guard bypass for testing + sub = Subscriber( + id=uuid4(), + topic="test", + graph_address=None, + leaky=True, + max_queue=5, + _guard=Subscriber._SENTINEL, + ) + + assert isinstance(sub._incoming, LeakyQueue) + assert sub._incoming.maxsize == 5 + + +def test_subscriber_init_default_unbounded_queue(): + """Subscriber creates standard asyncio.Queue by default (not LeakyQueue).""" + from ezmsg.core.subclient import Subscriber + from ezmsg.core.messagechannel import LeakyQueue + + sub = Subscriber( + id=uuid4(), + topic="test", + graph_address=None, + _guard=Subscriber._SENTINEL, + ) + + # Default should be a standard asyncio.Queue, not LeakyQueue + assert isinstance(sub._incoming, asyncio.Queue) + assert not isinstance(sub._incoming, LeakyQueue) + assert sub._incoming.maxsize == 0 # 0 means unlimited in asyncio.Queue + + +@pytest.mark.asyncio +async def test_graphcontext_subscriber_passes_leaky_params(): + """GraphContext.subscriber passes leaky and max_queue to Subscriber.create.""" + from unittest.mock import patch, AsyncMock + from ezmsg.core.graphcontext import GraphContext + + ctx = GraphContext(graph_address=None) + + with patch('ezmsg.core.graphcontext.Subscriber.create', new_callable=AsyncMock) as mock_create: + mock_create.return_value = AsyncMock() + + await ctx.subscriber("test_topic", leaky=True, max_queue=3) + + mock_create.assert_called_once_with( + "test_topic", + None, # graph_address + leaky=True, + max_queue=3 + ) + + +def test_subscriber_decorator_extracts_leaky_from_stream(): + """@subscriber decorator extracts leaky config from InputStream.""" + from ezmsg.core.unit import subscriber, LEAKY_ATTR, MAX_QUEUE_ATTR + + INPUT = InputStream(float, leaky=True, max_queue=10) + + @subscriber(INPUT) + async def process(self, msg: float): + pass + + assert getattr(process, LEAKY_ATTR, False) is True + assert getattr(process, MAX_QUEUE_ATTR, None) == 10 + + +def test_subscriber_decorator_default_not_leaky(): + """@subscriber decorator defaults to non-leaky.""" + from ezmsg.core.unit import subscriber, LEAKY_ATTR, MAX_QUEUE_ATTR + + INPUT = InputStream(float) + + @subscriber(INPUT) + async def process(self, msg: float): + pass + + assert getattr(process, LEAKY_ATTR, False) is False + assert getattr(process, MAX_QUEUE_ATTR, None) is None + + +@pytest.mark.asyncio +async def test_leaky_queue_under_load(): + """Test LeakyQueue behavior under simulated load.""" + from ezmsg.core.messagechannel import LeakyQueue + + queue = LeakyQueue(3) + + # Simulate fast producer + for i in range(100): + await queue.put(("pub", i)) + + # Queue should only have last 3 items + assert queue.qsize() == 3 + + items = [] + while not queue.empty(): + items.append(await queue.get()) + + # Should have the most recent items + assert items == [("pub", 97), ("pub", 98), ("pub", 99)] + + +@pytest.mark.asyncio +async def test_leaky_queue_on_drop_callback(): + """LeakyQueue calls on_drop callback when dropping items.""" + from ezmsg.core.messagechannel import LeakyQueue + + dropped_items = [] + + def on_drop(item): + dropped_items.append(item) + + queue = LeakyQueue(2, on_drop) + + # Fill the queue + await queue.put(("pub", 1)) + await queue.put(("pub", 2)) + assert len(dropped_items) == 0 + + # This should trigger drop of item 1 + await queue.put(("pub", 3)) + assert len(dropped_items) == 1 + assert dropped_items[0] == ("pub", 1) + + # This should trigger drop of item 2 + await queue.put(("pub", 4)) + assert len(dropped_items) == 2 + assert dropped_items[1] == ("pub", 2) + + +@pytest.mark.asyncio +async def test_leaky_queue_on_drop_callback_put_nowait(): + """LeakyQueue calls on_drop callback when using put_nowait.""" + from ezmsg.core.messagechannel import LeakyQueue + + dropped_items = [] + + def on_drop(item): + dropped_items.append(item) + + queue = LeakyQueue(2, on_drop) + + # Fill the queue + queue.put_nowait(("pub", 1)) + queue.put_nowait(("pub", 2)) + assert len(dropped_items) == 0 + + # This should trigger drop of item 1 + queue.put_nowait(("pub", 3)) + assert len(dropped_items) == 1 + assert dropped_items[0] == ("pub", 1) + + +def test_channel_release_without_get(): + """Channel.release_without_get frees backpressure without yielding message.""" + from unittest.mock import MagicMock + from ezmsg.core.messagechannel import Channel + + pub_id = uuid4() + client_id = uuid4() + num_buffers = 4 + + # Create channel with guard bypass + chan = Channel( + id=uuid4(), + pub_id=pub_id, + num_buffers=num_buffers, + shm=None, + graph_address=None, + _guard=Channel._SENTINEL, + ) + + # Register a client + chan.register_client(client_id, queue=MagicMock()) + + # Simulate a leased buffer (as if notify_clients was called) + msg_id = 100 + buf_idx = msg_id % num_buffers + chan.backpressure.lease(client_id, buf_idx) + + # Put a message in the cache (required for release to work) + chan.cache.put_local("test_message", msg_id) + + # Verify lease is held + assert not chan.backpressure.available(buf_idx) + + # Mock _acknowledge to verify it gets called + chan._acknowledge = MagicMock() + + # Release without getting the message + chan.release_without_get(msg_id, client_id) + + # Verify backpressure is freed + assert chan.backpressure.available(buf_idx) + + # Verify acknowledge was called (since no local backpressure) + chan._acknowledge.assert_called_once_with(msg_id) + + +def test_subscriber_leaky_queue_has_on_drop_callback(): + """Leaky Subscriber's queue has on_drop callback configured.""" + from ezmsg.core.subclient import Subscriber + from ezmsg.core.messagechannel import LeakyQueue + + sub = Subscriber( + id=uuid4(), + topic="test", + graph_address=None, + leaky=True, + max_queue=5, + _guard=Subscriber._SENTINEL, + ) + + assert isinstance(sub._incoming, LeakyQueue) + assert sub._incoming._on_drop is not None + # The callback should be the subscriber's _handle_dropped_notification method + assert sub._incoming._on_drop == sub._handle_dropped_notification + + +def test_subscriber_handle_dropped_notification_releases_backpressure(): + """Subscriber._handle_dropped_notification calls channel.release_without_get.""" + from unittest.mock import MagicMock + from ezmsg.core.subclient import Subscriber + from ezmsg.core.messagechannel import Channel + + pub_id = uuid4() + sub_id = uuid4() + num_buffers = 4 + + # Create subscriber with leaky queue + sub = Subscriber( + id=sub_id, + topic="test", + graph_address=None, + leaky=True, + max_queue=2, + _guard=Subscriber._SENTINEL, + ) + + # Create a mock channel + mock_channel = MagicMock(spec=Channel) + mock_channel.num_buffers = num_buffers + + # Register the channel with the subscriber + sub._channels[pub_id] = mock_channel + + # Simulate a dropped notification + msg_id = 100 + sub._handle_dropped_notification((pub_id, msg_id)) + + # Verify release_without_get was called on the channel + mock_channel.release_without_get.assert_called_once_with(msg_id, sub_id) + + +def test_leaky_subscriber_backpressure_integration(): + """ + Integration test: leaky subscriber releases backpressure when dropping. + + Simulates the full flow: + 1. Channel notifies subscriber (taking backpressure lease) + 2. Queue fills up + 3. New notification causes old one to be dropped + 4. Dropped notification triggers backpressure release + """ + from unittest.mock import MagicMock + from ezmsg.core.subclient import Subscriber + from ezmsg.core.messagechannel import Channel, LeakyQueue + + pub_id = uuid4() + sub_id = uuid4() + num_buffers = 4 + + # Create a real channel + chan = Channel( + id=uuid4(), + pub_id=pub_id, + num_buffers=num_buffers, + shm=None, + graph_address=None, + _guard=Channel._SENTINEL, + ) + chan._acknowledge = MagicMock() # Mock TCP acknowledgment + + # Create subscriber with leaky queue (max_queue=2) + sub = Subscriber( + id=sub_id, + topic="test", + graph_address=None, + leaky=True, + max_queue=2, + _guard=Subscriber._SENTINEL, + ) + + # Wire the channel to the subscriber + sub._channels[pub_id] = chan + + # Register subscriber with channel (simulating CHANNELS.register) + chan.register_client(sub_id, sub._incoming) + + # Simulate channel notifying subscriber for msg_id=0, 1, 2 + # This is what _notify_clients does + for msg_id in range(3): + buf_idx = msg_id % num_buffers + # Put message in cache (simulating what _publisher_connection does) + chan.cache.put_local(f"message_{msg_id}", msg_id) + chan.backpressure.lease(sub_id, buf_idx) + sub._incoming.put_nowait((pub_id, msg_id)) + + # Queue should have only 2 items (msg_id 1 and 2) + # msg_id 0 should have been dropped and its backpressure released + assert sub._incoming.qsize() == 2 + + # Verify backpressure for buf_idx=0 (msg_id=0) was released + assert chan.backpressure.available(0), "Backpressure for dropped msg should be released" + + # Verify backpressure for buf_idx=1,2 (msg_id=1,2) is still held + assert not chan.backpressure.available(1), "Backpressure for queued msg should be held" + assert not chan.backpressure.available(2), "Backpressure for queued msg should be held" + + # Verify the acknowledge was called for the dropped message + chan._acknowledge.assert_called_once_with(0) From 9faa5b5848758552a2a0b88d2a214c58a6766c39 Mon Sep 17 00:00:00 2001 From: Preston Peranich Date: Thu, 15 Jan 2026 10:19:39 -0500 Subject: [PATCH 2/3] chore: lint fix --- src/ezmsg/core/messagechannel.py | 2 +- tests/test_generator.py | 1 - tests/test_leaky_subscriber.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/ezmsg/core/messagechannel.py b/src/ezmsg/core/messagechannel.py index 8d10948..0059129 100644 --- a/src/ezmsg/core/messagechannel.py +++ b/src/ezmsg/core/messagechannel.py @@ -9,7 +9,7 @@ from .shm import SHMContext from .messagemarshal import MessageMarshal from .backpressure import Backpressure -from .messagecache import MessageCache, CacheMiss +from .messagecache import MessageCache from .graphserver import GraphService from .netprotocol import ( Command, diff --git a/tests/test_generator.py b/tests/test_generator.py index df01300..7736567 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -1,7 +1,6 @@ from collections.abc import AsyncGenerator, Generator import copy import json -import os import typing import numpy as np diff --git a/tests/test_leaky_subscriber.py b/tests/test_leaky_subscriber.py index 71c17ad..55f5286 100644 --- a/tests/test_leaky_subscriber.py +++ b/tests/test_leaky_subscriber.py @@ -331,7 +331,7 @@ def test_leaky_subscriber_backpressure_integration(): """ from unittest.mock import MagicMock from ezmsg.core.subclient import Subscriber - from ezmsg.core.messagechannel import Channel, LeakyQueue + from ezmsg.core.messagechannel import Channel pub_id = uuid4() sub_id = uuid4() From c66cb075be91f5d4d92975e1c12592166066d824 Mon Sep 17 00:00:00 2001 From: Preston Peranich Date: Thu, 15 Jan 2026 17:22:01 -0500 Subject: [PATCH 3/3] fix: recv (w/ copy) in sub if queue is leaky. --- src/ezmsg/core/backendprocess.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/src/ezmsg/core/backendprocess.py b/src/ezmsg/core/backendprocess.py index cf79394..7cbb0a6 100644 --- a/src/ezmsg/core/backendprocess.py +++ b/src/ezmsg/core/backendprocess.py @@ -20,6 +20,7 @@ from .stream import Stream, InputStream, OutputStream from .unit import Unit, TIMEIT_ATTR, SUBSCRIBES_ATTR, ZERO_COPY_ATTR +from .messagechannel import LeakyQueue from .graphcontext import GraphContext from .pubclient import Publisher @@ -411,12 +412,20 @@ async def handle_subscriber( :param callables: Set of async callables to invoke with messages. :type callables: set[Callable[..., Coroutine[Any, Any, None]]] """ + # Leaky subscribers use recv() to copy and release backpressure immediately, + # allowing publishers to continue without blocking during slow processing. + # Non-leaky subscribers use recv_zero_copy() to hold backpressure during + # processing, which provides zero-copy performance but applies backpressure. + is_leaky = isinstance(sub._incoming, LeakyQueue) + while True: if not callables: sub.close() await sub.wait_closed() break - async with sub.recv_zero_copy() as msg: + + if is_leaky: + msg = await sub.recv() try: for callable in list(callables): try: @@ -425,6 +434,16 @@ async def handle_subscriber( callables.remove(callable) finally: del msg + else: + async with sub.recv_zero_copy() as msg: + try: + for callable in list(callables): + try: + await callable(msg) + except (Complete, NormalTermination): + callables.remove(callable) + finally: + del msg if len(callables) > 1: await asyncio.sleep(0)