171 changes: 171 additions & 0 deletions examples/ezmsg_leaky_subscriber.py
@@ -0,0 +1,171 @@
# Leaky Subscriber Example
#
# This example demonstrates the "leaky subscriber" feature, which allows
# slow consumers to drop old messages rather than blocking fast producers.
#
# Scenario:
# - A fast publisher produces messages at ~10 Hz (every 100ms)
# - A slow subscriber processes messages at ~1 Hz (1000ms per message)
# - Without leaky mode: the publisher would be blocked by backpressure
# - With leaky mode: old messages are dropped, subscriber always gets recent data
#
# This is useful for real-time applications where you want the latest data
# rather than processing a growing backlog of stale messages.
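#
# Back-of-envelope expectation (approximate; exact counts depend on timing):
# the publisher needs ~20 x 0.1 s = 2 s to send everything, during which the
# 1 Hz subscriber completes only ~2 messages. The leaky queue then holds at
# most max_queue=3 of the stragglers, so roughly 5-6 messages get processed
# and ~14-15 are dropped, with queueing latency bounded by about
# max_queue * process_time_sec rather than growing with the backlog.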

import asyncio
from collections.abc import AsyncGenerator
from dataclasses import dataclass, field

import ezmsg.core as ez


@dataclass
class TimestampedMessage:
    """A message with sequence number and timestamp for tracking latency."""

    seq: int
    created_at: float = field(default_factory=lambda: asyncio.get_event_loop().time())


class FastPublisherSettings(ez.Settings):
    num_messages: int = 20
    publish_interval_sec: float = 0.1  # 10 Hz


class FastPublisher(ez.Unit):
    """Publishes messages at ~10 Hz."""

    SETTINGS = FastPublisherSettings

    OUTPUT = ez.OutputStream(TimestampedMessage, num_buffers=32)

    @ez.publisher(OUTPUT)
    async def publish(self) -> AsyncGenerator:
        # Small delay to ensure subscriber is ready
        await asyncio.sleep(0.5)

        for seq in range(self.SETTINGS.num_messages):
            msg = TimestampedMessage(seq=seq)
            print(f"[Publisher] Sending seq={seq}", flush=True)
            yield (self.OUTPUT, msg)
            await asyncio.sleep(self.SETTINGS.publish_interval_sec)

        print("[Publisher] Done sending all messages", flush=True)
        raise ez.Complete


class SlowSubscriberSettings(ez.Settings):
    process_time_sec: float = 1.0  # Simulates slow processing at ~1 Hz
    expected_messages: int = 20


class SlowSubscriberState(ez.State):
    received_count: int = 0
    received_seqs: list | None = None
    total_latency: float = 0.0


class SlowSubscriber(ez.Unit):
    """
    A slow subscriber that takes 1 second to process each message.

    Uses a leaky InputStream to drop old messages when it can't keep up,
    ensuring it always processes relatively recent data.
    """

    SETTINGS = SlowSubscriberSettings
    STATE = SlowSubscriberState

    # Leaky input stream with max queue of 3 messages.
    # When the queue fills up, the oldest messages are dropped.
    INPUT = ez.InputStream(TimestampedMessage, leaky=True, max_queue=3)

    async def initialize(self) -> None:
        self.STATE.received_seqs = []

    @ez.subscriber(INPUT)
    async def on_message(self, msg: TimestampedMessage) -> None:
        now = asyncio.get_event_loop().time()
        latency_ms = (now - msg.created_at) * 1000

        self.STATE.received_count += 1
        self.STATE.total_latency += latency_ms
        self.STATE.received_seqs.append(msg.seq)

        print(
            f"[Subscriber] Processing seq={msg.seq:3d}, "
            f"latency={latency_ms:6.0f}ms",
            flush=True,
        )

        # Simulate slow processing
        await asyncio.sleep(self.SETTINGS.process_time_sec)

        # Terminate after receiving the last message
        if msg.seq == self.SETTINGS.expected_messages - 1:
            raise ez.NormalTermination

    async def shutdown(self) -> None:
        dropped = self.SETTINGS.expected_messages - self.STATE.received_count
        avg_latency = (
            self.STATE.total_latency / self.STATE.received_count
            if self.STATE.received_count > 0
            else 0
        )

        print("\n" + "=" * 60, flush=True)
        print("LEAKY SUBSCRIBER SUMMARY", flush=True)
        print("=" * 60, flush=True)
        print(f"  Messages published: {self.SETTINGS.expected_messages}", flush=True)
        print(f"  Messages received:  {self.STATE.received_count}", flush=True)
        print(f"  Messages dropped:   {dropped}", flush=True)
        print(f"  Sequences received: {self.STATE.received_seqs}", flush=True)
        print(f"  Average latency:    {avg_latency:.0f}ms", flush=True)
        print("=" * 60, flush=True)
        print(
            "\nNote: With leaky=True, the subscriber drops old messages to stay\n"
            "      current. Without it, backpressure would slow the publisher.",
            flush=True,
        )


class LeakyDemo(ez.Collection):
    """Demo system with a fast publisher and slow leaky subscriber."""

    SETTINGS = FastPublisherSettings

    PUB = FastPublisher()
    SUB = SlowSubscriber()

    def configure(self) -> None:
        num_msgs = self.SETTINGS.num_messages
        self.PUB.apply_settings(
            FastPublisherSettings(
                num_messages=num_msgs,
                publish_interval_sec=self.SETTINGS.publish_interval_sec,
            )
        )
        self.SUB.apply_settings(
            SlowSubscriberSettings(process_time_sec=1.0, expected_messages=num_msgs)
        )

    def network(self) -> ez.NetworkDefinition:
        return ((self.PUB.OUTPUT, self.SUB.INPUT),)


if __name__ == "__main__":
    print("Leaky Subscriber Demo", flush=True)
    print("=" * 60, flush=True)
    print("Publisher: 20 messages at 10 Hz (100ms intervals)", flush=True)
    print("Subscriber: Processes at 1 Hz (1000ms per message)", flush=True)
    print("Queue: max_queue=3, leaky=True", flush=True)
    print("=" * 60, flush=True)
    print("\nExpected behavior:", flush=True)
    print("- Publisher sends 20 messages over ~2 seconds", flush=True)
    print("- Subscriber can only process ~1 message per second", flush=True)
    print("- Many messages will be dropped to keep the subscriber current", flush=True)
    print("=" * 60 + "\n", flush=True)

    settings = FastPublisherSettings(num_messages=20, publish_interval_sec=0.1)
    system = LeakyDemo(settings)
    ez.run(DEMO=system)
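Aside (not part of the PR): the drop-oldest behavior that leaky=True requests can be modeled with a few lines of plain asyncio. This standalone sketch, with hypothetical names, mirrors the semantics of the LeakyQueue added in messagechannel.py below.

import asyncio


class DropOldestQueue(asyncio.Queue):
    """Illustrative drop-oldest queue: put_nowait() evicts the oldest item when full."""

    def put_nowait(self, item):
        if self.full():
            self.get_nowait()  # evict the oldest item instead of raising QueueFull
        super().put_nowait(item)


async def main():
    q = DropOldestQueue(maxsize=3)
    for i in range(10):  # fast producer: 10 items, but the queue keeps only 3
        q.put_nowait(i)
    print([q.get_nowait() for _ in range(q.qsize())])  # [7, 8, 9]


asyncio.run(main())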
28 changes: 26 additions & 2 deletions src/ezmsg/core/backendprocess.py
@@ -20,6 +20,7 @@
 
 from .stream import Stream, InputStream, OutputStream
 from .unit import Unit, TIMEIT_ATTR, SUBSCRIBES_ATTR, ZERO_COPY_ATTR
+from .messagechannel import LeakyQueue
 
 from .graphcontext import GraphContext
 from .pubclient import Publisher
@@ -201,7 +202,12 @@ async def setup_state():
         if isinstance(stream, InputStream):
             logger.debug(f"Creating Subscriber from {stream}")
             sub = asyncio.run_coroutine_threadsafe(
-                context.subscriber(stream.address), loop
+                context.subscriber(
+                    stream.address,
+                    leaky=stream.leaky,
+                    max_queue=stream.max_queue,
+                ),
+                loop,
             ).result()
             task_name = f"SUBSCRIBER|{stream.address}"
             coro_callables[task_name] = partial(
@@ -406,12 +412,20 @@ async def handle_subscriber(
     :param callables: Set of async callables to invoke with messages.
     :type callables: set[Callable[..., Coroutine[Any, Any, None]]]
     """
+    # Leaky subscribers use recv() to copy the message and release backpressure
+    # immediately, so publishers can continue without blocking during slow
+    # processing. Non-leaky subscribers use recv_zero_copy() to hold backpressure
+    # while processing, trading publisher throughput for zero-copy access.
+    is_leaky = isinstance(sub._incoming, LeakyQueue)
+
     while True:
         if not callables:
             sub.close()
             await sub.wait_closed()
             break
-        async with sub.recv_zero_copy() as msg:
+
+        if is_leaky:
+            msg = await sub.recv()
             try:
                 for callable in list(callables):
                     try:
@@ -420,6 +434,16 @@
                         callables.remove(callable)
             finally:
                 del msg
+        else:
+            async with sub.recv_zero_copy() as msg:
+                try:
+                    for callable in list(callables):
+                        try:
+                            await callable(msg)
+                        except (Complete, NormalTermination):
+                            callables.remove(callable)
+                finally:
+                    del msg
 
         if len(callables) > 1:
             await asyncio.sleep(0)
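Reviewer's sketch (illustrative only; ToySubscriber is a stand-in, not ezmsg's real Subscriber API): the branch above chooses between two receive disciplines. This toy model shows why the leaky path releases the shared buffer before processing, while the zero-copy path holds it until processing finishes.

import asyncio
from contextlib import asynccontextmanager


class ToySubscriber:
    """Stand-in subscriber; 'held' models backpressure on the publisher."""

    def __init__(self):
        self._buffer = {"seq": 0}  # pretend shared-memory buffer
        self.held = False

    async def recv(self):
        return dict(self._buffer)  # copy out; the buffer is free immediately

    @asynccontextmanager
    async def recv_zero_copy(self):
        self.held = True  # publisher is blocked while the buffer is held
        try:
            yield self._buffer  # no copy: caller works on the shared buffer
        finally:
            self.held = False  # backpressure released only when processing ends


async def main():
    sub = ToySubscriber()

    msg = await sub.recv()  # leaky path: release-first, then process
    print("leaky: held during processing =", sub.held)  # False

    async with sub.recv_zero_copy() as msg:  # non-leaky path: hold-until-done
        print("zero-copy: held during processing =", sub.held)  # True


asyncio.run(main())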
88 changes: 77 additions & 11 deletions src/ezmsg/core/messagechannel.py
@@ -25,7 +25,48 @@
 logger = logging.getLogger("ezmsg")
 
 
-NotificationQueue = asyncio.Queue[typing.Tuple[UUID, int]]
+class LeakyQueue(asyncio.Queue[typing.Tuple[UUID, int]]):
+    """
+    An asyncio.Queue that drops its oldest items when full.
+
+    When putting a new item into a full queue, the oldest item is
+    dropped to make room.
+
+    :param maxsize: Maximum queue size (must be positive)
+    :param on_drop: Optional callback invoked with each dropped item
+    """
+
+    def __init__(
+        self,
+        maxsize: int,
+        on_drop: typing.Callable[[typing.Any], None] | None = None,
+    ):
+        super().__init__(maxsize=maxsize)
+        self._on_drop = on_drop
+
+    def _drop_oldest(self) -> None:
+        """Drop the oldest item from the queue, calling on_drop if set."""
+        try:
+            dropped = self.get_nowait()
+            if self._on_drop is not None:
+                self._on_drop(dropped)
+        except asyncio.QueueEmpty:
+            pass
+
+    async def put(self, item: typing.Tuple[UUID, int]) -> None:
+        """Put an item into the queue, dropping the oldest if full."""
+        if self.full():
+            self._drop_oldest()
+        await super().put(item)
+
+    def put_nowait(self, item: typing.Tuple[UUID, int]) -> None:
+        """Put an item without blocking, dropping the oldest if full."""
+        if self.full():
+            self._drop_oldest()
+        super().put_nowait(item)
+
+
+NotificationQueue = asyncio.Queue[typing.Tuple[UUID, int]] | LeakyQueue
 
 
 class Channel:
@@ -310,16 +351,41 @@ def get(
         try:
             yield self.cache[msg_id]
         finally:
-            buf_idx = msg_id % self.num_buffers
-            self.backpressure.free(client_id, buf_idx)
-            if self.backpressure.buffers[buf_idx].is_empty:
-                self.cache.release(msg_id)
-
-            # If pub is in same process as this channel, avoid TCP
-            if self._local_backpressure is not None:
-                self._local_backpressure.free(self.id, buf_idx)
-            else:
-                self._acknowledge(msg_id)
+            self._release_backpressure(msg_id, client_id)
+
+    def release_without_get(self, msg_id: int, client_id: UUID) -> None:
+        """
+        Release backpressure for a message without retrieving it.
+
+        Used by leaky subscribers when dropping notifications to ensure
+        backpressure is properly released for messages that will never be read.
+
+        :param msg_id: Message ID to release
+        :type msg_id: int
+        :param client_id: UUID of client releasing this message
+        :type client_id: UUID
+        """
+        self._release_backpressure(msg_id, client_id)
+
+    def _release_backpressure(self, msg_id: int, client_id: UUID) -> None:
+        """
+        Internal method to release backpressure for a message.
+
+        :param msg_id: Message ID to release
+        :type msg_id: int
+        :param client_id: UUID of client releasing this message
+        :type client_id: UUID
+        """
+        buf_idx = msg_id % self.num_buffers
+        self.backpressure.free(client_id, buf_idx)
+        if self.backpressure.buffers[buf_idx].is_empty:
+            self.cache.release(msg_id)
+
+        # If pub is in same process as this channel, avoid TCP
+        if self._local_backpressure is not None:
+            self._local_backpressure.free(self.id, buf_idx)
+        else:
+            self._acknowledge(msg_id)
 
     def _acknowledge(self, msg_id: int) -> None:
         try:
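Usage sketch for the new LeakyQueue (assuming the module imports as packaged; behavior follows the implementation shown above): items are (publisher UUID, msg_id) notification tuples, and on_drop sees each evicted item.

import asyncio
from uuid import uuid4

from ezmsg.core.messagechannel import LeakyQueue


async def main():
    dropped = []
    q = LeakyQueue(maxsize=2, on_drop=dropped.append)

    pub_id = uuid4()
    for msg_id in range(5):
        q.put_nowait((pub_id, msg_id))  # evicts the oldest once the queue is full

    print([item[1] for item in dropped])  # [0, 1, 2] -- the three oldest were dropped
    print([q.get_nowait()[1] for _ in range(q.qsize())])  # [3, 4] -- newest survive


asyncio.run(main())

In the real subscriber client, an on_drop callback along these lines is what would invoke release_without_get(), so that dropped notifications still free their shared-memory buffers.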