diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
index 3b6d32f6..3c04e930 100644
--- a/.git-blame-ignore-revs
+++ b/.git-blame-ignore-revs
@@ -2,3 +2,5 @@
 5b4c220154d33cac15a75d3d7d978a75e67f2b8a
 # Ran `black` formatter on entire codebase and fixed non-standard line-endings with `dos2unix`.
 4e4f20be40a73f2162a565732bae76ea0c812739
+# chore: ruff formatting
+5ca5711e7714042f23d1286abf946217d75318c2
diff --git a/docs/source/conf.py b/docs/source/conf.py
index d7a569fc..e6e8e389 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -1,7 +1,6 @@
 # Configuration file for the Sphinx documentation builder.
 
 import os
-import sys
 
 
 # -- Project information --------------------------
@@ -75,7 +74,7 @@
 html_static_path = ["_static"]
 
 # Timestamp is inserted at every page bottom in this strftime format.
-html_last_updated_fmt = '%Y-%m-%d'
+html_last_updated_fmt = "%Y-%m-%d"
 
 # -- Options for EPUB output --------------------------
 epub_show_urls = "footnote"
@@ -95,8 +94,10 @@ def linkcode_resolve(domain, info):
         else:
             return f"{code_url}src/{filename}.py"
 
+
 # -- Options for graphviz -----------------------------
 graphviz_output_format = "svg"
 
+
 def setup(app):
-    app.add_css_file("custom.css")
\ No newline at end of file
+    app.add_css_file("custom.css")
diff --git a/examples/ezmsg_attach.py b/examples/ezmsg_attach.py
index 09fe4786..b75241be 100644
--- a/examples/ezmsg_attach.py
+++ b/examples/ezmsg_attach.py
@@ -5,4 +5,4 @@
 if __name__ == "__main__":
     print("This example attaches to the system created/run by ezmsg_toy.py.")
     log = DebugLog()
-    ez.run(log, connections=(("TestSystem/PING/OUTPUT", log.INPUT),))
+    ez.run(LOG=log, connections=(("GLOBAL_PING_TOPIC", log.INPUT),))
diff --git a/examples/ezmsg_configs.py b/examples/ezmsg_configs.py
index d60934f2..113295eb 100644
--- a/examples/ezmsg_configs.py
+++ b/examples/ezmsg_configs.py
@@ -239,5 +239,5 @@ def network(self) -> ez.NetworkDefinition:
     ]
 
     for system in test_systems:
-        ez.logger.info(f"Testing { system.__name__ }")
+        ez.logger.info(f"Testing {system.__name__}")
         ez.run(system())
diff --git a/examples/ezmsg_toy.py b/examples/ezmsg_toy.py
index 2d1bf3f6..a5c5c772 100644
--- a/examples/ezmsg_toy.py
+++ b/examples/ezmsg_toy.py
@@ -192,8 +192,8 @@ def process_components(self):
 
     ez.run(
         SYSTEM=system,
-        # connections = [
-        # ( system.PING.OUTPUT, 'PING_OUTPUT' ),
-        # ( 'FOO_SUB', system.FOOSUB.INPUT )
-        # ]
+        connections=[
+            # Make PING.OUTPUT available on a topic that ezmsg_attach.py can subscribe to
+            (system.PING.OUTPUT, "GLOBAL_PING_TOPIC"),
+        ],
     )
diff --git a/examples/lowlevel_api.py b/examples/lowlevel_api.py
new file mode 100644
index 00000000..9bc5d586
--- /dev/null
+++ b/examples/lowlevel_api.py
@@ -0,0 +1,119 @@
+import asyncio
+
+import ezmsg.core as ez
+
+PORT = 12345
+MAX_COUNT = 100
+TOPIC = "/TEST"
+
+
+async def handle_pub(pub: ez.Publisher) -> None:
+    print("Publisher Task Launched")
+
+    count = 0
+
+    while True:
+        await pub.broadcast(f"{count=}")
+        await asyncio.sleep(0.1)
+        count += 1
+        if count >= MAX_COUNT:
+            break
+
+    print("Publisher Task Concluded")
+
+
+async def handle_sub(sub: ez.Subscriber) -> None:
+    print("Subscriber Task Launched")
+
+    rx_count = 0
+    while True:
+        async with sub.recv_zero_copy() as msg:
+            # Uncomment if you want to witness backpressure!
+ # await asyncio.sleep(0.15) + print(msg) + + rx_count += 1 + if rx_count >= MAX_COUNT: + break + + print("Subscriber Task Concluded") + + +async def host(host: str = "127.0.0.1"): + # Manually create a GraphServer + server = ez.GraphServer() + server.start((host, PORT)) + + print(f"Created GraphServer @ {server.address}") + + try: + test_pub = await ez.Publisher.create(TOPIC, (host, PORT), host=host) + test_sub1 = await ez.Subscriber.create(TOPIC, (host, PORT)) + test_sub2 = await ez.Subscriber.create(TOPIC, (host, PORT)) + + await asyncio.sleep(1.0) + + pub_task = asyncio.Task(handle_pub(test_pub)) + sub_task_1 = asyncio.Task(handle_sub(test_sub1)) + sub_task_2 = asyncio.Task(handle_sub(test_sub2)) + + await asyncio.wait([pub_task, sub_task_1, sub_task_2]) + + test_pub.close() + test_sub1.close() + test_sub2.close() + + for future in asyncio.as_completed( + [ + test_pub.wait_closed(), + test_sub1.wait_closed(), + test_sub2.wait_closed(), + ] + ): + await future + + finally: + server.stop() + + print("Done") + + +async def attach_client(host: str = "127.0.0.1"): + + sub = await ez.Subscriber.create(TOPIC, (host, PORT)) + + try: + while True: + async with sub.recv_zero_copy() as msg: + # Uncomment if you want to see EXTREME backpressure! + # await asyncio.sleep(1.0) + print(msg) + + except asyncio.CancelledError: + pass + + finally: + sub.close() + await sub.wait_closed() + print("Detached") + + +if __name__ == "__main__": + from dataclasses import dataclass + from argparse import ArgumentParser + + parser = ArgumentParser() + parser.add_argument("--attach", action="store_true", help="attach to running graph") + parser.add_argument("--host", default="0.0.0.0", help="hostname for graphserver") + + @dataclass + class Args: + attach: bool + host: str + + args = Args(**vars(parser.parse_args())) + + if args.attach: + asyncio.run(attach_client(host=args.host)) + else: + asyncio.run(host(host=args.host)) diff --git a/pyproject.toml b/pyproject.toml index cbe88ce7..c9449919 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ dev = [ {include-group = "lint"}, {include-group = "test"}, "pre-commit>=4.3.0", + "viztracer>=1.0.4", ] lint = [ "flake8>=7.3.0", @@ -35,6 +36,7 @@ test = [ "pytest-asyncio>=1.1.0", "pytest-cov>=6.2.1", "xarray>=2025.6.1", + "psutil>=7.1.0", ] docs = [ {include-group = "axisarray"}, @@ -51,6 +53,7 @@ axisarray = [ [project.scripts] ezmsg = "ezmsg.core.command:cmdline" +ezmsg-perf = "ezmsg.util.perf.command:command" [project.optional-dependencies] axisarray = [ diff --git a/src/ezmsg/core/__init__.py b/src/ezmsg/core/__init__.py index e3ada5cd..fc361531 100644 --- a/src/ezmsg/core/__init__.py +++ b/src/ezmsg/core/__init__.py @@ -24,6 +24,8 @@ "GraphServer", "GraphContext", "run_command", + "Publisher", + "Subscriber", # All following are deprecated "System", "run_system", @@ -42,6 +44,8 @@ from .graphserver import GraphServer from .graphcontext import GraphContext from .command import run_command +from .pubclient import Publisher +from .subclient import Subscriber # Following imports are deprecated from .backend import run_system diff --git a/src/ezmsg/core/addressable.py b/src/ezmsg/core/addressable.py index e74d2731..9dd15f82 100644 --- a/src/ezmsg/core/addressable.py +++ b/src/ezmsg/core/addressable.py @@ -4,17 +4,18 @@ class Addressable: """ Base class for objects that can be addressed within the ezmsg system. 
- - Addressable objects have a hierarchical address structure consisting of + + Addressable objects have a hierarchical address structure consisting of a location path and a name, similar to a filesystem path. """ + _name: str | None _location: list[str] | None def __init__(self) -> None: """ Initialize an Addressable object. - + The name and location are initially None and must be set before the object can be properly addressed. This is achieved through the ``_set_name()`` and ``_set_location()`` methods. @@ -38,7 +39,7 @@ def _set_location(self, location: list[str] | None = None): def name(self) -> str: """ Get the name of this addressable object. - + :return: The object's name :rtype: str :raises AssertionError: If name has not been set @@ -51,7 +52,7 @@ def name(self) -> str: def location(self) -> list[str]: """ Get the location path of this addressable object. - + :return: List of path components representing the object's location :rtype: list[str] :raises AssertionError: If location has not been set @@ -64,10 +65,10 @@ def location(self) -> list[str]: def address(self) -> str: """ Get the full address of this object. - + The address is constructed by joining the location path and name with forward slashes, similar to a filesystem path. - + :return: The full address string :rtype: str """ diff --git a/src/ezmsg/core/backend.py b/src/ezmsg/core/backend.py index 2d988d33..09b35bd6 100644 --- a/src/ezmsg/core/backend.py +++ b/src/ezmsg/core/backend.py @@ -32,47 +32,56 @@ class ExecutionContext: - processes: list[BackendProcess] + _process_units: list[list[Unit]] + _processes: list[BackendProcess] | None + term_ev: EventType start_barrier: BarrierType connections: list[tuple[str, str]] def __init__( self, - processes: list[list[Unit]], - graph_service: GraphService, + process_units: list[list[Unit]], connections: list[tuple[str, str]] = [], - backend_process: type[BackendProcess] = DefaultBackendProcess, ) -> None: - if not processes: - raise ValueError("Cannot create an execution context for zero processes") - self.connections = connections + self._process_units = process_units + self._processes = None self.term_ev = Event() - self.start_barrier = Barrier(len(processes)) - self.stop_barrier = Barrier(len(processes)) + self.start_barrier = Barrier(len(process_units)) + self.stop_barrier = Barrier(len(process_units)) - self.processes = [ + def create_processes( + self, + graph_address: AddressType | None, + backend_process: type[BackendProcess] = DefaultBackendProcess, + ) -> None: + self._processes = [ backend_process( process_units, self.term_ev, self.start_barrier, self.stop_barrier, - graph_service, + graph_address, ) - for process_units in processes + for process_units in self._process_units ] + @property + def processes(self) -> list[BackendProcess]: + if self._processes is None: + raise ValueError("ExecutionContext has not initialized processes") + else: + return self._processes + @classmethod def setup( cls, components: Mapping[str, Component], - graph_service: GraphService, root_name: str | None = None, connections: NetworkDefinition | None = None, process_components: AbstractCollection[Component] | None = None, - backend_process: type[BackendProcess] = DefaultBackendProcess, force_single_process: bool = False, ) -> "ExecutionContext | None": graph_connections: list[tuple[str, str]] = [] @@ -133,16 +142,14 @@ def configure_collections(comp: Component): if force_single_process: processes = [[u for pu in processes for u in pu]] - try: - return cls( - processes, - graph_service, - 
graph_connections, - backend_process, - ) - except ValueError: + if not processes: return None + return cls( + processes, + graph_connections, + ) + def run_system( system: Collection, @@ -152,8 +159,8 @@ def run_system( ) -> None: """ Deprecated function for running a system (Collection). - - .. deprecated:: + + .. deprecated:: Use :func:`run` instead to run any component (unit, collection). :param system: The collection to run @@ -184,9 +191,9 @@ def run( This is the main entry point for running ezmsg applications. It sets up the execution environment, initializes components, and manages the message-passing - infrastructure. + infrastructure. - On initialization, ezmsg will call ``initialize()`` for each :obj:`Unit` and + On initialization, ezmsg will call ``initialize()`` for each :obj:`Unit` and ``configure()`` for each :obj:`Collection`, if defined. On initialization, ezmsg will create a directed acyclic graph using the contents of ``connections``. @@ -210,16 +217,15 @@ def run( :param components_kwargs: Additional components specified as keyword arguments :type components_kwargs: Component - .. note:: + .. note:: Since jupyter notebooks run in a single process, you must set `force_single_process=True`. - + .. note:: The old method :obj:`run_system` has been deprecated and uses ``run()`` instead. """ os.environ["EZMSG_PROFILER"] = profiler_log_name or "ezprofiler.log" # FIXME: This function is the last major re-implementation needed to make this # codebase more maintainable. - graph_service = GraphService(graph_address) if components is not None and isinstance(components, Component): components = {"SYSTEM": components} @@ -233,11 +239,9 @@ def run( with new_threaded_event_loop() as loop: execution_context = ExecutionContext.setup( components, - graph_service, root_name, connections, process_components, - backend_process, force_single_process, ) @@ -246,7 +250,7 @@ def run( # FIXME: When done this way, we don't exit the graph_context on exception async def create_graph_context() -> GraphContext: - return await GraphContext(graph_service).__aenter__() + return await GraphContext(graph_address).__aenter__() # FIXME: This sort of stuff should all be done in a separate async function... 
# Done this way, its ugly as hell and opens us up to a lot of issues with @@ -255,6 +259,18 @@ async def create_graph_context() -> GraphContext: create_graph_context(), loop ).result() + if graph_context._graph_server is None: + address = graph_context.graph_address + if address is None: + address = GraphService.default_address() + logger.info(f"Connected to GraphServer @ {address}") + else: + logger.info(f"Spawned GraphServer @ {graph_context.graph_address}") + + execution_context.create_processes( + graph_address=graph_context.graph_address, backend_process=backend_process + ) + async def cleanup_graph() -> None: await graph_context.__aexit__(None, None, None) diff --git a/src/ezmsg/core/backendprocess.py b/src/ezmsg/core/backendprocess.py index a1a66d47..ba93ead5 100644 --- a/src/ezmsg/core/backendprocess.py +++ b/src/ezmsg/core/backendprocess.py @@ -6,6 +6,7 @@ import traceback import threading +from abc import abstractmethod from collections import defaultdict from collections.abc import Callable, Coroutine, Generator, Sequence from functools import wraps, partial @@ -21,11 +22,9 @@ from .unit import Unit, TIMEIT_ATTR, SUBSCRIBES_ATTR, ZERO_COPY_ATTR from .graphcontext import GraphContext -from .graphserver import GraphService from .pubclient import Publisher from .subclient import Subscriber -from .messagecache import MessageCache -from abc import abstractmethod +from .netprotocol import AddressType logger = logging.getLogger("ezmsg") @@ -34,7 +33,7 @@ class Complete(Exception): """ A type of Exception raised by Unit methods, which signals to ezmsg that the function can be shut down gracefully. - + If all functions in all Units raise Complete, the entire pipeline will terminate execution. This exception is used to signal normal completion of processing tasks. @@ -49,7 +48,7 @@ class Complete(Exception): class NormalTermination(Exception): """ A type of Exception which signals to ezmsg that the pipeline can be shut down gracefully. - + This exception is used to indicate that the system should terminate normally, typically when all processing is complete or when a graceful shutdown is requested. @@ -63,16 +62,17 @@ class NormalTermination(Exception): class BackendProcess(Process): """ Abstract base class for backend processes that execute Units. - + BackendProcess manages the execution of Units in a separate process, handling initialization, coordination with other processes via barriers, and cleanup operations. """ + units: list[Unit] term_ev: EventType start_barrier: BarrierType stop_barrier: BarrierType - graph_service: GraphService + graph_address: AddressType | None def __init__( self, @@ -80,11 +80,11 @@ def __init__( term_ev: EventType, start_barrier: BarrierType, stop_barrier: BarrierType, - graph_service: GraphService, + graph_address: AddressType | None, ) -> None: """ Initialize the backend process. - + :param units: List of Units to execute in this process. :type units: list[Unit] :param term_ev: Event for coordinated termination. @@ -103,13 +103,13 @@ def __init__( self.term_ev = term_ev self.start_barrier = start_barrier self.stop_barrier = stop_barrier - self.graph_service = graph_service + self.graph_address = graph_address self.task_finished_ev: threading.Event | None = None def run(self) -> None: """ Main entry point for the process execution. - + Sets up the event loop and handles the main processing logic with proper exception handling for interrupts. 
""" @@ -124,10 +124,10 @@ def run(self) -> None: def process(self, loop: asyncio.AbstractEventLoop) -> None: """ Abstract method for implementing the main processing logic. - + Subclasses must implement this method to define how Units are executed within the event loop. - + :param loop: The asyncio event loop for this process. :type loop: asyncio.AbstractEventLoop :raises NotImplementedError: Must be implemented by subclasses. @@ -138,16 +138,17 @@ def process(self, loop: asyncio.AbstractEventLoop) -> None: class DefaultBackendProcess(BackendProcess): """ Default implementation of BackendProcess for executing Units. - + This class provides the standard execution model for ezmsg Units, handling publishers, subscribers, and the complete Unit lifecycle including initialization, execution, and shutdown. """ + pubs: dict[str, Publisher] def process(self, loop: asyncio.AbstractEventLoop) -> None: main_func = None - context = GraphContext(self.graph_service) + context = GraphContext(self.graph_address) coro_callables: dict[str, Callable[[], Coroutine[Any, Any, None]]] = dict() try: @@ -229,11 +230,9 @@ async def setup_state(): logger.debug("Waiting at start barrier!") self.start_barrier.wait() - threads = [ - loop.run_in_executor(None, thread_fn, unit) - for unit in self.units - for thread_fn in unit.threads.values() - ] + for unit in self.units: + for thread_fn in unit.threads.values(): + loop.run_in_executor(None, thread_fn, unit) for pub in self.pubs.values(): pub.resume() @@ -306,8 +305,8 @@ async def shutdown_units() -> None: asyncio.run_coroutine_threadsafe(shutdown_units(), loop=loop).result() - for cache in MessageCache.values(): - cache.clear() + # for cache in MessageCache.values(): + # cache.clear() asyncio.run_coroutine_threadsafe(context.revert(), loop=loop).result() @@ -397,11 +396,11 @@ async def handle_subscriber( ): """ Handle incoming messages from a subscriber and distribute to callables. - + Continuously receives messages from the subscriber and calls all registered callables with each message. Removes callables that raise Complete or NormalTermination exceptions. - + :param sub: Subscriber to receive messages from. :type sub: Subscriber :param callables: Set of async callables to invoke with messages. @@ -429,10 +428,10 @@ async def handle_subscriber( def run_loop(loop: asyncio.AbstractEventLoop): """ Run an asyncio event loop in the current thread. - + Sets the event loop for the current thread and runs it forever until interrupted or stopped. - + :param loop: The asyncio event loop to run. :type loop: asyncio.AbstractEventLoop """ @@ -449,10 +448,10 @@ def new_threaded_event_loop( ) -> Generator[asyncio.AbstractEventLoop, None, None]: """ Create a new asyncio event loop running in a separate thread. - + Provides a context manager that yields an event loop running in its own thread, allowing async operations to be run from synchronous code. - + :param ev: Optional event to signal when the loop is ready. :type ev: threading.Event | None :return: Context manager yielding the event loop. diff --git a/src/ezmsg/core/backpressure.py b/src/ezmsg/core/backpressure.py index 6a91759e..56c0e7cf 100644 --- a/src/ezmsg/core/backpressure.py +++ b/src/ezmsg/core/backpressure.py @@ -8,17 +8,18 @@ class BufferLease: """ Manages leases for a single buffer in the backpressure system. - + A BufferLease tracks which clients have active leases on a buffer and provides synchronization when the buffer becomes empty. 
""" + leases: set[UUID] empty: asyncio.Event def __init__(self) -> None: """ Initialize a new BufferLease. - + The buffer starts empty with no active leases. """ self.leases = set() @@ -28,7 +29,7 @@ def __init__(self) -> None: def add(self, uuid: UUID) -> None: """ Add a lease for the specified client. - + :param uuid: Unique identifier for the client :type uuid: UUID """ @@ -38,7 +39,7 @@ def add(self, uuid: UUID) -> None: def remove(self, uuid: UUID) -> None: """ Remove a lease for the specified client. - + :param uuid: Unique identifier for the client :type uuid: UUID """ @@ -49,7 +50,7 @@ def remove(self, uuid: UUID) -> None: async def wait(self) -> Literal[True]: """ Wait until the buffer becomes empty (no active leases). - + :return: Always returns True when the buffer is empty :rtype: Literal[True] """ @@ -59,7 +60,7 @@ async def wait(self) -> Literal[True]: def is_empty(self) -> bool: """ Check if the buffer has no active leases. - + :return: True if no leases are active, False otherwise :rtype: bool """ @@ -69,14 +70,15 @@ def is_empty(self) -> bool: class Backpressure: """ Manages backpressure for multiple buffers in a message passing system. - + The Backpressure class coordinates access to multiple buffers, tracking which buffers are in use and providing synchronization mechanisms to manage flow control between publishers and subscribers. - + :param num_buffers: Number of buffers to manage :type num_buffers: int """ + buffers: list[BufferLease] empty: asyncio.Event pressure: int @@ -84,7 +86,7 @@ class Backpressure: def __init__(self, num_buffers: int) -> None: """ Initialize backpressure management for the specified number of buffers. - + :param num_buffers: Number of buffers to create and manage :type num_buffers: int """ @@ -97,7 +99,7 @@ def __init__(self, num_buffers: int) -> None: def is_empty(self) -> bool: """ Check if all buffers are empty (no active pressure). - + :return: True if no buffers have active leases, False otherwise :rtype: bool """ @@ -106,7 +108,7 @@ def is_empty(self) -> bool: def available(self, buf_idx: int) -> bool: """ Check if a specific buffer is available (has no active leases). - + :param buf_idx: Index of the buffer to check :type buf_idx: int :return: True if the buffer is available, False otherwise @@ -117,7 +119,7 @@ def available(self, buf_idx: int) -> bool: async def wait(self, buf_idx: int) -> None: """ Wait until a specific buffer becomes available. - + :param buf_idx: Index of the buffer to wait for :type buf_idx: int """ @@ -126,7 +128,7 @@ async def wait(self, buf_idx: int) -> None: def lease(self, uuid: UUID, buf_idx: int) -> None: """ Create a lease on a specific buffer for the given client. - + :param uuid: Unique identifier for the client :type uuid: UUID :param buf_idx: Index of the buffer to lease @@ -140,7 +142,7 @@ def lease(self, uuid: UUID, buf_idx: int) -> None: def _free(self, uuid: UUID, buf_idx: int) -> None: """ Internal method to free a lease on a specific buffer. - + :param uuid: Unique identifier for the client :type uuid: UUID :param buf_idx: Index of the buffer to free @@ -156,10 +158,10 @@ def _free(self, uuid: UUID, buf_idx: int) -> None: def free(self, uuid: UUID, buf_idx: int | None = None) -> None: """ Free leases for the given client. - + If buf_idx is specified, only frees the lease on that buffer. If buf_idx is None, frees leases on all buffers for this client. 
- + :param uuid: Unique identifier for the client :type uuid: UUID :param buf_idx: Optional buffer index to free, or None to free all @@ -177,7 +179,7 @@ def free(self, uuid: UUID, buf_idx: int | None = None) -> None: async def sync(self) -> Literal[True]: """ Wait until all buffers are empty (no backpressure). - + :return: Always returns True when all buffers are empty :rtype: Literal[True] """ diff --git a/src/ezmsg/core/channelmanager.py b/src/ezmsg/core/channelmanager.py new file mode 100644 index 00000000..42231d06 --- /dev/null +++ b/src/ezmsg/core/channelmanager.py @@ -0,0 +1,132 @@ +import logging + +from uuid import UUID + +from .messagechannel import Channel, NotificationQueue +from .backpressure import Backpressure +from .netprotocol import Address, AddressType, GRAPHSERVER_ADDR + +logger = logging.getLogger("ezmsg") + + +def _ensure_address(address: AddressType | None) -> Address: + if address is None: + return Address.from_string(GRAPHSERVER_ADDR) + + elif not isinstance(address, Address): + return Address(*address) + + return address + + +class ChannelManager: + """ + ChannelManager maintains a process-specific registry of Channels, tracks what clients + need access to what channels, and creates/deallocates Channels accordingly + """ + + _registry: dict[Address, dict[UUID, Channel]] + + def __init__(self): + default_address = Address.from_string(GRAPHSERVER_ADDR) + self._registry = {default_address: dict()} + + async def register( + self, + pub_id: UUID, + client_id: UUID, + queue: NotificationQueue, + graph_address: AddressType | None = None, + ) -> Channel: + """ + Acquire the channel associated with a particular publisher, creating it if necessary + + :param pub_id: The UUID associated with the publisher to acquire a channel for + :type pub_id: UUID + :param client_id: The UUID associated with the client interested in this channel, used for deallocation when the channel has no more registered clients + :type client_id: UUID + :param queue: An asyncio Queue for the interested client that will be populated with incoming message notifications + :type queue: asyncio.Queue[tuple[UUID, int]] + :param graph_address: The address to the GraphServer that the requested publisher is managed by + :type graph_address: AddressType | None + :return: A Channel for retreiving messages from the requested Publisher + :rtype: Channel + """ + return await self._register(pub_id, client_id, queue, graph_address, None) + + async def register_local_pub( + self, + pub_id: UUID, + local_backpressure: Backpressure, + graph_address: AddressType | None = None, + ) -> Channel: + """ + Register/create a channel for a publisher that is local to (in the same process as) the publisher in question. + Because this message channel is local to the publisher, it will directly push messages into the channel and the channel will directly manage the publisher's Backpressure without any telemetry or serialization. + .. note:: Since only a Publisher should be registering a channel this way, it will not need access to the messages it is publishing, hence it will not provide a queue. 
+ + :param pub_id: The UUID associated with the publisher to acquire a channel for + :type pub_id: UUID + :param local_backpressure: The Backpressure object associated with the Publisher + :type local_backpressure: Backpressure + :param graph_address: The address to the GraphServer that the requested publisher is managed by + :type graph_address: AddressType | None + :return: A Channel that the Publisher can push messages to locally + :rtype: Channel + """ + return await self._register( + pub_id, pub_id, None, graph_address, local_backpressure + ) + + async def _register( + self, + pub_id: UUID, + client_id: UUID, + queue: NotificationQueue | None = None, + graph_address: AddressType | None = None, + local_backpressure: Backpressure | None = None, + ) -> Channel: + graph_address = _ensure_address(graph_address) + try: + channel = self._registry.get(graph_address, dict())[pub_id] + except KeyError: + channel = await Channel.create(pub_id, graph_address) + channels = self._registry.get(graph_address, dict()) + channels[pub_id] = channel + self._registry[graph_address] = channels + channel.register_client(client_id, queue, local_backpressure) + return channel + + async def unregister( + self, pub_id: UUID, client_id: UUID, graph_address: AddressType | None = None + ) -> None: + """ + Indicate to the ChannelManager that the client referred to by client_id no longer needs access to the Channel associated with the publisher referred to by pub_id. + If no clients need access to this channel, the channel will be closed and removed from the ChannelManager. + + :param pub_id: The UUID associated with the publisher to acquire a channel for + :type pub_id: UUID + :param client_id: The UUID associated with the client interested in this channel, used for deallocation when the channel has no more registered clients + :type client_id: UUID + :param graph_address: The address to the GraphServer that the requested publisher is managed by + :type graph_address: AddressType | None + """ + graph_address = _ensure_address(graph_address) + channel = self._registry.get(graph_address, dict())[pub_id] + channel.unregister_client(client_id) + + logger.debug( + f"unregistered {client_id} from {pub_id}; {len(channel.clients)} left" + ) + + if len(channel.clients) == 0: + registry = self._registry[graph_address] + del registry[pub_id] + + channel.close() + await channel.wait_closed() + + logger.debug(f"closed channel {pub_id}: no clients") + + +CHANNELS = ChannelManager() diff --git a/src/ezmsg/core/collection.py b/src/ezmsg/core/collection.py index 4ac6e8f5..dd8f665d 100644 --- a/src/ezmsg/core/collection.py +++ b/src/ezmsg/core/collection.py @@ -9,9 +9,7 @@ # Iterable of (output_stream, input_stream) pairs defining the network connections -NetworkDefinition = Iterable[ - tuple[Stream | str, Stream | str] -] +NetworkDefinition = Iterable[tuple[Stream | str, Stream | str]] class CollectionMeta(ComponentMeta): @@ -41,7 +39,7 @@ def __init__( class Collection(Component, metaclass=CollectionMeta): """ Connects :obj:`Units ` together by defining a graph which connects OutputStreams to InputStreams. - + Collections are composite components that contain and coordinate multiple Units, defining how they communicate through stream connections. @@ -59,8 +57,8 @@ def __init__(self, *args, settings: Settings | None = None, **kwargs): def configure(self) -> None: """ A lifecycle hook that runs when the Collection is instantiated. 
- - This is the best place to call ``Unit.apply_settings()`` on each member + + This is the best place to call ``Unit.apply_settings()`` on each member Unit of the Collection. Override this method to perform collection-specific configuration of child components. """ @@ -68,9 +66,9 @@ def configure(self) -> None: def network(self) -> NetworkDefinition: """ - Override this method and have the definition return a NetworkDefinition + Override this method and have the definition return a NetworkDefinition which defines how InputStreams and OutputStreams from member Units will be connected. - + The NetworkDefinition specifies the message routing between components by connecting output streams to input streams. @@ -81,10 +79,10 @@ def network(self) -> NetworkDefinition: def process_components(self) -> AbstractCollection[Component]: """ - Override this method and have the definition return a tuple which contains + Override this method and have the definition return a tuple which contains Units and Collections which should run in their own processes. - This method allows you to specify which components should be isolated + This method allows you to specify which components should be isolated in separate processes for performance or isolation requirements. :return: Collection of components that should run in separate processes diff --git a/src/ezmsg/core/command.py b/src/ezmsg/core/command.py index 6f9b766b..09ce2dc3 100644 --- a/src/ezmsg/core/command.py +++ b/src/ezmsg/core/command.py @@ -5,7 +5,6 @@ import logging import subprocess import sys -import typing import webbrowser import zlib @@ -25,7 +24,7 @@ def cmdline() -> None: """ Command-line interface for ezmsg core server management. - + Provides commands for starting, stopping, and managing ezmsg server processes including GraphServer and SHMServer, as well as utilities for graph visualization. @@ -107,7 +106,7 @@ async def run_command( ) -> None: """ Run an ezmsg command with the specified parameters. - + This function handles various ezmsg commands like 'serve', 'start', 'shutdown', etc. and manages the graph and shared memory services. @@ -184,7 +183,7 @@ async def run_command( def mm(graph: str, target="live") -> str: """ Generate a Mermaid visualization URL for the given graph. - + :param graph: Graph representation string to visualize. :type graph: str :param target: Target platform ('live' or 'ink'). diff --git a/src/ezmsg/core/component.py b/src/ezmsg/core/component.py index d9e42902..fd831b93 100644 --- a/src/ezmsg/core/component.py +++ b/src/ezmsg/core/component.py @@ -98,16 +98,16 @@ def __init__( class Component(Addressable, metaclass=ComponentMeta): """ Metaclass which :obj:`Unit` and :obj:`Collection` inherit from. - + The Component class provides the foundation for all components in the ezmsg framework, including Units and Collections. It manages settings, state, streams, and provides the basic infrastructure for message-passing components. - + :param settings: Optional settings object for component configuration :type settings: Settings | None .. note:: - + When creating ezmsg nodes, inherit directly from :obj:`Unit` or :obj:`Collection`. """ @@ -193,7 +193,7 @@ def _set_location(self, location: list[str] | None = None): def tasks(self) -> dict[str, Callable]: """ Get the dictionary of tasks for this component. 
- + :return: Dictionary mapping task names to their callable functions :rtype: dict[str, collections.abc.Callable] """ @@ -203,7 +203,7 @@ def tasks(self) -> dict[str, Callable]: def streams(self) -> dict[str, Stream]: """ Get the dictionary of streams for this component. - + :return: Dictionary mapping stream names to their Stream objects :rtype: dict[str, Stream] """ @@ -213,7 +213,7 @@ def streams(self) -> dict[str, Stream]: def components(self) -> dict[str, "Component"]: """ Get the dictionary of child components for this component. - + :return: Dictionary mapping component names to their Component objects :rtype: dict[str, Component] """ @@ -223,7 +223,7 @@ def components(self) -> dict[str, "Component"]: def main(self) -> Callable[..., None] | None: """ Get the main function for this component. - + :return: The main callable function, or None if not set :rtype: collections.abc.Callable[..., None] | None """ @@ -233,7 +233,7 @@ def main(self) -> Callable[..., None] | None: def threads(self) -> dict[str, Callable]: """ Get the dictionary of thread functions for this component. - + :return: Dictionary mapping thread names to their callable functions :rtype: dict[str, collections.abc.Callable] """ diff --git a/src/ezmsg/core/dag.py b/src/ezmsg/core/dag.py index 285958c4..d0c9a72a 100644 --- a/src/ezmsg/core/dag.py +++ b/src/ezmsg/core/dag.py @@ -3,13 +3,14 @@ from dataclasses import dataclass, field -class CyclicException(Exception): +class CyclicException(Exception): """ Exception raised when an operation would create a cycle in the DAG. - + This exception is raised when attempting to add an edge that would violate the acyclic property of the directed acyclic graph. """ + ... @@ -20,18 +21,19 @@ class CyclicException(Exception): class DAG: """ Directed Acyclic Graph implementation for managing dependencies and connections. - + The DAG class provides functionality to build and maintain a directed acyclic graph, which is used by ezmsg to manage message flow between components while ensuring no circular dependencies exist. """ + graph: GraphType = field(default_factory=lambda: defaultdict(set), init=False) @property def nodes(self) -> set[str]: """ Get all nodes in the graph. - + :return: Set of all node names in the graph :rtype: set[str] """ @@ -41,13 +43,13 @@ def nodes(self) -> set[str]: def invgraph(self) -> GraphType: """ Get the inverse (reversed) graph. - + Creates a graph where all edges are reversed. This is useful for finding upstream dependencies. - + :return: Inverted graph representation :rtype: GraphType - + .. note:: This is currently implemented inefficiently but is adequate for typical use cases. """ @@ -61,13 +63,13 @@ def invgraph(self) -> GraphType: def add_edge(self, from_node: str, to_node: str) -> None: """ Ensure an edge exists in the graph. - + Adds an edge from from_node to to_node. Does nothing if the edge already exists. If the edge would make the graph cyclic, raises CyclicException. - + :param from_node: Source node name :type from_node: str - :param to_node: Destination node name + :param to_node: Destination node name :type to_node: str :raises CyclicException: If adding the edge would create a cycle """ @@ -88,10 +90,10 @@ def add_edge(self, from_node: str, to_node: str) -> None: def remove_edge(self, from_node: str, to_node: str) -> None: """ Ensure an edge is not present in the graph. - + Removes an edge from from_node to to_node. Does nothing if the edge doesn't exist. Automatically prunes unconnected nodes after removal. 
- + :param from_node: Source node name :type from_node: str :param to_node: Destination node name @@ -103,9 +105,9 @@ def remove_edge(self, from_node: str, to_node: str) -> None: def downstream(self, from_node: str) -> list[str]: """ Get a list of downstream nodes (including from_node). - + Performs a breadth-first search to find all nodes reachable from the given node. - + :param from_node: Starting node name :type from_node: str :return: List of downstream node names including the starting node @@ -116,10 +118,10 @@ def downstream(self, from_node: str) -> list[str]: def upstream(self, from_node: str) -> list[str]: """ Get a list of upstream nodes (including from_node). - + Performs a breadth-first search on the inverted graph to find all nodes that can reach the given node. - + :param from_node: Starting node name :type from_node: str :return: List of upstream node names including the starting node @@ -139,9 +141,9 @@ def _prune(self) -> None: def _leaves(graph: GraphType) -> set[str]: """ Find leaf nodes in a graph. - + Leaf nodes are nodes that have no outgoing edges. - + :param graph: The graph to analyze :type graph: GraphType :return: Set of leaf node names @@ -153,9 +155,9 @@ def _leaves(graph: GraphType) -> set[str]: def _bfs(graph: GraphType, node: str) -> list[str]: """ Breadth-first search of a graph starting from a given node. - + Traverses the graph in breadth-first order to find all reachable nodes. - + :param graph: The graph to search :type graph: GraphType :param node: Starting node for the search diff --git a/src/ezmsg/core/graph_util.py b/src/ezmsg/core/graph_util.py index 832b44ef..f6807bc5 100644 --- a/src/ezmsg/core/graph_util.py +++ b/src/ezmsg/core/graph_util.py @@ -1,6 +1,5 @@ from collections import defaultdict from textwrap import indent -from collections import defaultdict from uuid import uuid4 @@ -13,11 +12,11 @@ def prune_graph_connections( ) -> tuple[GraphType | None, list[str] | None]: """ Remove nodes from the graph that are proxy_topics. - + Proxy topics are nodes that are both source and target nodes in the connections graph. This function removes them and connects their upstream sources directly to their downstream targets. - + :param graph_connections: Graph representing topic connections. :type graph_connections: GraphType :return: Tuple of (pruned_graph, proxy_topics_list). @@ -53,7 +52,7 @@ def _pipeline_levels( ) -> defaultdict: """ Compute hierarchy levels for pipeline components. - + In ezmsg, a pipeline is built with units/collections, with subcomponents that are either more units/collection or input/output streams. The graph of the connections are stored in a DAG (directed acyclic graph object), but the @@ -63,7 +62,7 @@ def _pipeline_levels( This function computes the level of each component in the hierarchy (not just the connection nodes) and returns a dictionary of the level of each pipeline parts, where 0 is the level of the connection (leaf) nodes. - + :param graph_connections: Graph representing component connections. :type graph_connections: GraphType :param level_separator: Character separating hierarchy levels. @@ -99,7 +98,7 @@ def _get_parent_node(node: str, level_separator: str = "/") -> str: class LeafNodeException(Exception): """ Raised when connection nodes are not leaf nodes in the pipeline hierarchy. - + This exception indicates that the graph contains connections at non-leaf levels of the component hierarchy, which violates expected structure. 
""" diff --git a/src/ezmsg/core/graphcontext.py b/src/ezmsg/core/graphcontext.py index e311b90c..f773569a 100644 --- a/src/ezmsg/core/graphcontext.py +++ b/src/ezmsg/core/graphcontext.py @@ -2,6 +2,7 @@ import logging import typing +from .netprotocol import AddressType from .graphserver import GraphServer, GraphService from .pubclient import Publisher from .subclient import Subscriber @@ -14,12 +15,12 @@ class GraphContext: """ GraphContext maintains a list of created publishers, subscribers, and connections in the graph. - + The GraphContext provides a managed environment for creating and tracking publishers, subscribers, and graph connections. When the context is no longer needed, it can revert changes in the graph which disconnects publishers and removes modifications that this context made. - + It also maintains a context manager that ensures the GraphServer is running. :param graph_service: Optional graph service instance to use @@ -33,95 +34,103 @@ class GraphContext: _clients: set[Publisher | Subscriber] _edges: set[tuple[str, str]] - _graph_service: GraphService + _graph_address: AddressType | None _graph_server: GraphServer | None def __init__( self, - graph_service: GraphService | None = None, + graph_address: AddressType | None = None, ) -> None: self._clients = set() self._edges = set() - self._graph_service = ( - graph_service if graph_service is not None else GraphService() - ) + self._graph_address = graph_address self._graph_server = None + @property + def graph_address(self) -> AddressType | None: + if self._graph_server is not None: + return self._graph_server.address + else: + return self._graph_address + async def publisher(self, topic: str, **kwargs) -> Publisher: """ Create a publisher for the specified topic. - + :param topic: The topic name to publish to :type topic: str :param kwargs: Additional keyword arguments for publisher configuration :return: A Publisher instance for the topic :rtype: Publisher """ - pub = await Publisher.create(topic, self._graph_service, **kwargs) + pub = await Publisher.create(topic, self.graph_address, **kwargs) + self._clients.add(pub) return pub async def subscriber(self, topic: str, **kwargs) -> Subscriber: """ Create a subscriber for the specified topic. - + :param topic: The topic name to subscribe to :type topic: str :param kwargs: Additional keyword arguments for subscriber configuration :return: A Subscriber instance for the topic :rtype: Subscriber """ - sub = await Subscriber.create(topic, self._graph_service, **kwargs) + sub = await Subscriber.create(topic, self.graph_address, **kwargs) + self._clients.add(sub) return sub async def connect(self, from_topic: str, to_topic: str) -> None: """ Connect two topics in the message graph. - + :param from_topic: The source topic name :type from_topic: str :param to_topic: The destination topic name :type to_topic: str """ - await self._graph_service.connect(from_topic, to_topic) + + await GraphService(self.graph_address).connect(from_topic, to_topic) self._edges.add((from_topic, to_topic)) async def disconnect(self, from_topic: str, to_topic: str) -> None: """ Disconnect two topics in the message graph. 
- + :param from_topic: The source topic name :type from_topic: str :param to_topic: The destination topic name :type to_topic: str """ - await self._graph_service.disconnect(from_topic, to_topic) + await GraphService(self.graph_address).disconnect(from_topic, to_topic) self._edges.discard((from_topic, to_topic)) async def sync(self, timeout: float | None = None) -> None: """ Synchronize with the graph server. - + :param timeout: Optional timeout for the sync operation :type timeout: float | None """ - await self._graph_service.sync(timeout) + await GraphService(self.graph_address).sync(timeout) async def pause(self) -> None: """ Pause message processing in the graph. """ - await self._graph_service.pause() + await GraphService(self.graph_address).pause() async def resume(self) -> None: """ Resume message processing in the graph. """ - await self._graph_service.resume() + await GraphService(self.graph_address).resume() async def _ensure_servers(self) -> None: - self._graph_server = await self._graph_service.ensure() + self._graph_server = await GraphService(self.graph_address).ensure() async def _shutdown_servers(self) -> None: if self._graph_server is not None: @@ -145,8 +154,8 @@ async def __aexit__( async def revert(self) -> None: """ Revert all changes made by this context. - - This method closes all clients (publishers and subscribers) created by this + + This method closes all publishers and subscribers created by this context and removes all edges that were added to the graph. It is automatically called when exiting the context manager. """ @@ -159,6 +168,6 @@ async def revert(self) -> None: for edge in self._edges: try: - await self._graph_service.disconnect(*edge) + await GraphService(self.graph_address).disconnect(*edge) except (ConnectionRefusedError, BrokenPipeError, ConnectionResetError) as e: logger.warn(f"Could not remove edge {edge} from GraphServer: {e}") diff --git a/src/ezmsg/core/graphserver.py b/src/ezmsg/core/graphserver.py index 4d92035d..cb97b9ec 100644 --- a/src/ezmsg/core/graphserver.py +++ b/src/ezmsg/core/graphserver.py @@ -1,8 +1,12 @@ import asyncio import logging import pickle +import os +import socket +import threading from contextlib import suppress -from uuid import UUID, getnode, uuid1 +from uuid import UUID, uuid1 + from . import __version__ from .dag import DAG, CyclicException @@ -13,6 +17,7 @@ ClientInfo, SubscriberInfo, PublisherInfo, + ChannelInfo, AddressType, close_stream_writer, encode_str, @@ -22,18 +27,20 @@ GRAPHSERVER_ADDR_ENV, GRAPHSERVER_PORT_DEFAULT, DEFAULT_SHM_SIZE, + create_socket, + SERVER_PORT_START_ENV, + SERVER_PORT_START_DEFAULT, + DEFAULT_HOST, ) -from .server import ServiceManager, ThreadedAsyncServer -from .shm import SharedMemory, SHMContext, SHMInfo - +from .shm import SHMContext, SHMInfo logger = logging.getLogger("ezmsg") -class GraphServer(ThreadedAsyncServer): +class GraphServer(threading.Thread): """ Pub-sub directed acyclic graph (DAG) server. - + The GraphServer manages the message routing graph for ezmsg applications, handling publisher-subscriber relationships and maintaining the DAG structure. @@ -45,33 +52,75 @@ class GraphServer(ThreadedAsyncServer): and doesn't need to be instantiated directly by user code. 
""" + _server_up: threading.Event + _shutdown: threading.Event + + _sock: socket.socket + _loop: asyncio.AbstractEventLoop + graph: DAG clients: dict[UUID, ClientInfo] - node: int shms: dict[str, SHMInfo] _client_tasks: dict[UUID, "asyncio.Task[None]"] _command_lock: asyncio.Lock - def __init__(self) -> None: - super().__init__() + def __init__(self, **kwargs) -> None: + super().__init__( + **{**kwargs, **dict(daemon=True, name=kwargs.get("name", "GraphServer"))} + ) + # threading events for lifecycle + self._server_up = threading.Event() + self._shutdown = threading.Event() + + # graph/server data self.graph = DAG() - self.clients = dict() - self._client_tasks = dict() + self.clients = {} + self._client_tasks = {} + self.shms = {} + + @property + def address(self) -> Address: + return Address(*self._sock.getsockname()) + + def start(self, address: AddressType | None = None) -> None: # type: ignore[override] + if address is not None: + self._sock = create_socket(*address) + else: + start_port = int( + os.environ.get(SERVER_PORT_START_ENV, SERVER_PORT_START_DEFAULT) + ) + self._sock = create_socket(start_port=start_port) - self.node = getnode() - self.shms = dict() + self._loop = asyncio.new_event_loop() + super().start() + self._server_up.wait() - async def setup(self) -> None: + def stop(self) -> None: + self._shutdown.set() + self.join() + + def run(self) -> None: + try: + asyncio.set_event_loop(self._loop) + with suppress(asyncio.CancelledError): + self._loop.run_until_complete(self._amain()) + finally: + self._loop.stop() + self._loop.close() + + async def _setup(self) -> None: self._command_lock = asyncio.Lock() - async def shutdown(self) -> None: - for task in self._client_tasks.values(): + async def _shutdown_async(self) -> None: + # Cancel client handler tasks and wait for them to end. + for task in list(self._client_tasks.values()): task.cancel() - with suppress(asyncio.CancelledError): - await task + with suppress(asyncio.CancelledError): + await asyncio.gather(*self._client_tasks.values(), return_exceptions=True) self._client_tasks.clear() + # Cancel SHM leases for info in self.shms.values(): for lease_task in list(info.leases): lease_task.cancel() @@ -79,20 +128,47 @@ async def shutdown(self) -> None: await lease_task info.leases.clear() + async def _amain(self) -> None: + """ + Start the asyncio server and serve forever until shutdown is requested. + """ + await self._setup() + + server = await asyncio.start_server(self.api, sock=self._sock) + + async def monitor_shutdown() -> None: + # Thread event -> wake in event loop + await self._loop.run_in_executor(None, self._shutdown.wait) + server.close() + await self._shutdown_async() + await server.wait_closed() + + monitor_task = self._loop.create_task(monitor_shutdown()) + self._server_up.set() + + try: + await server.serve_forever() + finally: + self._shutdown.set() + await monitor_task + async def api( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter ) -> None: try: - node = await read_int(reader) + # Always start communications by telling the client our ezmsg version + # This helps us get ahead of surprisingly common situations where + # the graph server and the graph clients are running different versions + # of ezmsg which could result in borked comms. 
writer.write(encode_str(__version__)) await writer.drain() req = await reader.read(1) - # Empty bytes object means EOF; Client disconnected - # This happens frequently when future clients are just pinging - # GraphServer to check if server is up if not req: + # Empty bytes object means EOF; Client disconnected + # This happens frequently when future clients are just pinging + # GraphServer to check if server is up await close_stream_writer(writer) return @@ -102,67 +178,118 @@ async def api( return elif req in [Command.SHM_CREATE.value, Command.SHM_ATTACH.value]: - info: Optional[SHMInfo] = None + shm_info: SHMInfo | None = None if req == Command.SHM_CREATE.value: num_buffers = await read_int(reader) buf_size = await read_int(reader) # Create segment - shm = SharedMemory(size=num_buffers * buf_size, create=True) - shm.buf[:] = b"0" * len(shm.buf) # Guarantee zeros - shm.buf[0:8] = uint64_to_bytes(num_buffers) - shm.buf[8:16] = uint64_to_bytes(buf_size) - info = SHMInfo(shm) - self.shms[shm.name] = info - logger.debug(f"created {shm.name}") + shm_info = SHMInfo.create(num_buffers, buf_size) + self.shms[shm_info.shm.name] = shm_info + logger.debug(f"created {shm_info.shm.name}") elif req == Command.SHM_ATTACH.value: shm_name = await read_str(reader) - info = self.shms.get(shm_name, None) + shm_info = self.shms.get(shm_name, None) - if info is None: + if shm_info is None: await close_stream_writer(writer) return writer.write(Command.COMPLETE.value) - writer.write(encode_str(info.shm.name)) - + writer.write(encode_str(shm_info.shm.name)) + + # NOTE: SHMContexts are like GraphClients in that their + # lifetime is bound to the lifetime of this connection + # but rather than the early return that we have with + # GraphClients, we just await the lease task. + # This may be a more readable pattern? + # NOTE: With the current shutdown pattern, we cancel + # the lease task then await it. When we cancel the lease + # task this resolves then it gets awaited in shutdown. with suppress(asyncio.CancelledError): - await info.lease(reader, writer) + await shm_info.lease(reader, writer) + + elif req == Command.CHANNEL.value: + channel_id = uuid1() + pub_id_str = await read_str(reader) + pub_id = UUID(pub_id_str) + + pub_addr = None + try: + pub_info = self.clients[pub_id] + if isinstance(pub_info, PublisherInfo): + # Advertise an address the channel should be able to resolve + port = pub_info.address.port + iface = pub_info.writer.transport.get_extra_info("peername")[0] + pub_addr = Address(iface, port) + else: + logger.warning(f"Connecting channel requested {type(pub_info)}") + + except KeyError: + logger.warning( + f"Connecting channel requested non-existent publisher {pub_id=}" + ) + + if pub_addr is None: + # FIXME: Channel should not be created; but it feels like we should + # have a better communication protocol to tell the channel what the + # error was and deliver a better experience from the client side. 
+ # for now, drop connection + await close_stream_writer(writer) + return + + writer.write(Command.COMPLETE.value) + writer.write(encode_str(str(channel_id))) + pub_addr.to_stream(writer) + + info = ChannelInfo(channel_id, writer, pub_id) + self.clients[channel_id] = info + self._client_tasks[channel_id] = asyncio.create_task( + self._handle_client(channel_id, reader, writer) + ) + + # NOTE: Created a client, must return early + # to avoid closing writer + return else: # We only want to handle one command at a time async with self._command_lock: - if req in [Command.SUBSCRIBE.value, Command.PUBLISH.value]: - id = uuid1(node=node) - writer.write(encode_str(str(id))) - - pid = await read_int(reader) + if req in [ + Command.SUBSCRIBE.value, + Command.PUBLISH.value, + ]: + client_id = uuid1() topic = await read_str(reader) + writer.write(encode_str(str(client_id))) + if req == Command.SUBSCRIBE.value: - info = SubscriberInfo(id, writer, pid, topic) - self.clients[id] = info - self._client_tasks[id] = asyncio.create_task( - self._handle_client(id, reader, writer) + info = SubscriberInfo(client_id, writer, topic) + self.clients[client_id] = info + self._client_tasks[client_id] = asyncio.create_task( + self._handle_client(client_id, reader, writer) ) - iface = writer.transport.get_extra_info("sockname")[0] - await self._notify_subscriber(info, iface) + + await self._notify_subscriber(info) elif req == Command.PUBLISH.value: address = await Address.from_stream(reader) - info = PublisherInfo(id, writer, pid, topic, address) - self.clients[id] = info - self._client_tasks[id] = asyncio.create_task( - self._handle_client(id, reader, writer) + info = PublisherInfo(client_id, writer, topic, address) + self.clients[client_id] = info + self._client_tasks[client_id] = asyncio.create_task( + self._handle_client(client_id, reader, writer) ) - iface = writer.transport.get_extra_info("peername")[0] + for sub in self._downstream_subs(info.topic): - await self._notify_subscriber(sub, iface) + await self._notify_subscriber(sub) writer.write(Command.COMPLETE.value) - await writer.drain() + + # NOTE: Created a client, must return early + # to avoid closing writer return elif req in [Command.CONNECT.value, Command.DISCONNECT.value]: @@ -217,12 +344,32 @@ async def api( except (ConnectionResetError, BrokenPipeError): logger.debug("GraphServer connection fail mid-command") + # NOTE: This prevents code repetition for many graph server commands, but + # when we create GraphClients, their lifecycle is bound to the lifecycle of + # this connection. We do NOT want to close the stream writer if we have + # created a GraphClient, which requires an early return. Perhaps a different + # communication protocol could resolve this await close_stream_writer(writer) async def _handle_client( - self, id: UUID, reader: asyncio.StreamReader, writer: asyncio.StreamWriter + self, + client_id: UUID, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, ) -> None: - logger.debug(f"Graph Server: Client connected: {id}") + """ + The lifecycle of a graph client is tied to the lifecycle of this TCP connection. + We always attempt to read from the reader, and if the connection is ever closed + from the other side, reader.read(1) returns empty, we break the loop, and delete + the client from our list of connected clients. Because we're always reading + from the client within this task, we have to handle all client responses within + this task. 
+ + NOTE: _notify_subscriber requires a COMPLETE once the subscriber has reconfigured + its connections. ClientInfo.sync_writer gives us the mechanism by which to block + _notify_subscriber until the COMPLETE is received in this context. + """ + logger.debug(f"Graph Server: Client connected: {client_id}") try: while True: @@ -232,33 +379,26 @@ async def _handle_client( break if req == Command.COMPLETE.value: - self.clients[id].set_sync() + self.clients[client_id].set_sync() except (ConnectionResetError, BrokenPipeError) as e: - logger.debug(f"Client {id} disconnected from GraphServer: {e}") + logger.debug(f"Client {client_id} disconnected from GraphServer: {e}") finally: - self.clients[id].set_sync() - del self.clients[id] + # Ensure any waiter on this client unblocks + # with suppress(Exception): + self.clients[client_id].set_sync() + self.clients.pop(client_id, None) await close_stream_writer(writer) - async def _notify_subscriber( - self, sub: SubscriberInfo, iface: str | None = None - ) -> None: + async def _notify_subscriber(self, sub: SubscriberInfo) -> None: try: - notification = [] - for pub in self._upstream_pubs(sub.topic): - address = pub.address - if address is not None: - if iface is not None: - notification.append(f"{str(pub.id)}@{iface}:{address.port}") - else: - notification.append( - f"{str(pub.id)}@{address.host}:{address.port}" - ) + pub_ids = [str(pub.id) for pub in self._upstream_pubs(sub.topic)] + # Update requires us to read a 'COMPLETE' + # This cannot be done from this context async with sub.sync_writer() as writer: - notify_str = ",".join(notification) + notify_str = ",".join(pub_ids) writer.write(Command.UPDATE.value) writer.write(encode_str(notify_str)) @@ -275,6 +415,9 @@ def _subscribers(self) -> list[SubscriberInfo]: info for info in self.clients.values() if isinstance(info, SubscriberInfo) ] + def _channels(self) -> list[ChannelInfo]: + return [info for info in self.clients.values() if isinstance(info, ChannelInfo)] + def _upstream_pubs(self, topic: str) -> list[PublisherInfo]: """Given a topic, return a set of all publisher IDs upstream of that topic""" upstream_topics = self.graph.upstream(topic) @@ -286,19 +429,56 @@ def _downstream_subs(self, topic: str) -> list[SubscriberInfo]: return [sub for sub in self._subscribers() if sub.topic in downstream_topics] -class GraphService(ServiceManager[GraphServer]): +class GraphService: ADDR_ENV = GRAPHSERVER_ADDR_ENV PORT_DEFAULT = GRAPHSERVER_PORT_DEFAULT + _address: Address | None + def __init__(self, address: AddressType | None = None) -> None: - super().__init__(GraphServer, address) + self._address = Address(*address) if address is not None else None + + @classmethod + def default_address(cls) -> Address: + address_str = os.environ.get(cls.ADDR_ENV, f"{DEFAULT_HOST}:{cls.PORT_DEFAULT}") + return Address.from_string(address_str) + + @property + def address(self) -> Address: + return self._address if self._address is not None else self.default_address() + + def create_server(self) -> GraphServer: + server = GraphServer(name="GraphServer") + server.start(self._address) + self._address = server.address + return server + + async def ensure(self) -> GraphServer | None: + """ + Try connecting to an existing server. If none is listening and no explicit + address/environment is set, start one and return it. If an existing one is + found, return None. 
+ """ + server = None + ensure_server = False + if self._address is None: + # Only auto-start if env var not forcing a location + ensure_server = self.ADDR_ENV not in os.environ + + try: + reader, writer = await self.open_connection() + await close_stream_writer(writer) + except OSError as ref_e: + if not ensure_server: + raise ref_e + server = self.create_server() + + return server async def open_connection( self, ) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]: - reader, writer = await super().open_connection() - writer.write(uint64_to_bytes(getnode())) - await writer.drain() + reader, writer = await asyncio.open_connection(*(self.address)) server_version = await read_str(reader) if server_version != __version__: logger.warning( @@ -314,6 +494,7 @@ async def connect(self, from_topic: str, to_topic: str) -> None: await writer.drain() response = await reader.read(1) if response == Command.CYCLIC.value: + await close_stream_writer(writer) raise CyclicException await close_stream_writer(writer) @@ -354,7 +535,7 @@ async def dag(self, timeout: float | None = None) -> DAG: writer.write(Command.DAG.value) await writer.drain() await asyncio.sleep(1.0) - await reader.readexactly(1) + await reader.read(1) dag_num_bytes = await read_int(reader) dag_bytes = await reader.readexactly(dag_num_bytes) dag: DAG = pickle.loads(dag_bytes) @@ -400,7 +581,9 @@ async def get_formatted_graph( return formatted_graph async def create_shm( - self, num_buffers: int, buf_size: int = DEFAULT_SHM_SIZE + self, + num_buffers: int, + buf_size: int = DEFAULT_SHM_SIZE, ) -> SHMContext: reader, writer = await self.open_connection() writer.write(Command.SHM_CREATE.value) @@ -410,10 +593,11 @@ async def create_shm( response = await reader.read(1) if response != Command.COMPLETE.value: + await close_stream_writer(writer) raise ValueError("Error creating SHM segment") shm_name = await read_str(reader) - return SHMContext._create(shm_name, reader, writer) + return SHMContext.attach(shm_name, reader, writer) async def attach_shm(self, name: str) -> SHMContext: reader, writer = await self.open_connection() @@ -423,7 +607,8 @@ async def attach_shm(self, name: str) -> SHMContext: response = await reader.read(1) if response != Command.COMPLETE.value: + await close_stream_writer(writer) raise ValueError("Invalid SHM Name") shm_name = await read_str(reader) - return SHMContext._create(shm_name, reader, writer) + return SHMContext.attach(shm_name, reader, writer) diff --git a/src/ezmsg/core/message.py b/src/ezmsg/core/message.py index ca3225d0..fced71d6 100644 --- a/src/ezmsg/core/message.py +++ b/src/ezmsg/core/message.py @@ -25,14 +25,15 @@ def __new__( class Message(ABC, metaclass=MessageMeta): """ Deprecated base class for messages in the ezmsg framework. - + .. deprecated:: Message is deprecated. Use @dataclass decorators instead of inheriting - from ez.Message. For data arrays, use :obj:`ezmsg.util.messages.AxisArray`. - + from ez.Message. For data arrays, use :obj:`ezmsg.util.messages.AxisArray`. + .. note:: This class will issue a DeprecationWarning when instantiated. """ + def __init__(self): warnings.warn( "Message is deprecated. Replace ez.Message with @dataclass decorators", @@ -45,10 +46,9 @@ def __init__(self): class Flag: """ A message with no contents. - + Flag is used as a simple signal message that carries no data, typically used for synchronization or simple event notification. """ ... 
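# The deprecation note above points at plain dataclasses as the replacement for
# ez.Message. A minimal illustrative sketch (MyMessage is not part of this change):

from dataclasses import dataclass


@dataclass
class MyMessage:
    value: float
    label: str = "unlabeled"


# Units exchange MyMessage instances directly on their streams; Flag (above)
# remains available when an event needs no payload at all.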
- diff --git a/src/ezmsg/core/messagecache.py b/src/ezmsg/core/messagecache.py index c246f3b6..f14cd2f7 100644 --- a/src/ezmsg/core/messagecache.py +++ b/src/ezmsg/core/messagecache.py @@ -1,15 +1,6 @@ -import logging +import typing -from uuid import UUID -from contextlib import contextmanager - -from .shm import SHMContext -from .messagemarshal import MessageMarshal, UninitializedMemory - -from collections.abc import Generator -from typing import Any - -logger = logging.getLogger("ezmsg") +from .messagemarshal import MessageMarshal class CacheMiss(Exception): @@ -19,21 +10,24 @@ class CacheMiss(Exception): This occurs when trying to retrieve a message that has been evicted from the cache or was never stored in the first place. """ - ... -class Cache: +class CacheEntry(typing.NamedTuple): + object: typing.Any + msg_id: int + context: typing.ContextManager | None + memory: memoryview | None + + +class MessageCache: """ - Shared-memory backed cache for objects. + Cache for memoryview-backed objects. - Provides a buffer cache that can store objects both in memory - and in shared memory buffers, enabling efficient message passing between + Provides a buffer cache that can store objects in memory + enabling efficient message passing between processes with automatic eviction based on buffer age. """ - - num_buffers: int - cache: list[Any] - cache_id: list[int | None] + _cache: list[CacheEntry | None] def __init__(self, num_buffers: int) -> None: """ @@ -42,92 +36,103 @@ def __init__(self, num_buffers: int) -> None: :param num_buffers: Number of cache buffers to maintain. :type num_buffers: int """ - self.num_buffers = num_buffers - self.cache_id = [None] * self.num_buffers - self.cache = [None] * self.num_buffers + self._cache = [None] * num_buffers + + def _buf_idx(self, msg_id: int) -> int: + return msg_id % len(self._cache) - def put(self, msg_id: int, msg: Any) -> None: + def __getitem__(self, msg_id: int) -> typing.Any: """ - Put an object into cache at the position determined by message ID. - - :param msg_id: Unique message identifier. + Get a cached object by msg_id + + :param msg_id: Message ID to retreive from cache :type msg_id: int - :param msg: The message object to cache. - :type msg: Any + :raises CacheMiss: If this msg_id does not exist in the cache. """ - buf_idx = msg_id % self.num_buffers - self.cache_id[buf_idx] = msg_id - self.cache[buf_idx] = msg + entry = self._cache[self._buf_idx(msg_id)] + if entry is None or entry.msg_id != msg_id: + raise CacheMiss + return entry.object - def push(self, msg_id: int, shm: SHMContext) -> None: + def keys(self) -> list[int]: """ - Push an object from cache into shared memory. - - If the message is not already in shared memory with the correct ID, - retrieves it from cache and serializes it to the shared memory buffer. - - :param msg_id: Message identifier to push. - :type msg_id: int - :param shm: Shared memory context to write to. - :type shm: SHMContext - :raises ValueError: If shared memory has wrong number of buffers. 
+ Get a list of current cached msg_ids """ - if self.num_buffers != shm.num_buffers: - raise ValueError("shm has incorrect number of buffers") - - buf_idx = msg_id % self.num_buffers - with shm.buffer(buf_idx) as mem: - shm_msg_id = MessageMarshal.msg_id(mem) - if shm_msg_id != msg_id: - with self.get(msg_id) as obj: - MessageMarshal.to_mem(msg_id, obj, mem) - - @contextmanager - def get( - self, msg_id: int, shm: SHMContext | None = None - ) -> Generator[Any, None, None]: + return [entry.msg_id for entry in self._cache if entry is not None] + + def put_local(self, obj: typing.Any, msg_id: int) -> None: """ - Get object from cache; if not in cache and shm provided -- get from shm. - - Provides a context manager for safe access to cached messages. If the - message is not in memory cache, attempts to retrieve from shared memory. - - :param msg_id: Message identifier to retrieve. + Put an object with associated msg_id directly into cache + + :param obj: Object to put in cache. + :type obj: typing.Any + :param msg_id: ID associated with this message/object. :type msg_id: int - :param shm: Optional shared memory context as fallback. - :type shm: SHMContext | None - :return: Context manager yielding the requested message. - :rtype: Generator[Any, None, None] - :raises CacheMiss: If message not found in cache or shared memory. """ - - buf_idx = msg_id % self.num_buffers - if self.cache_id[buf_idx] == msg_id: - yield self.cache[buf_idx] - - else: - if shm is None: - raise CacheMiss - - with shm.buffer(buf_idx, readonly=True) as mem: - try: - if MessageMarshal.msg_id(mem) != msg_id: - raise CacheMiss - except UninitializedMemory: - raise CacheMiss - - with MessageMarshal.obj_from_mem(mem) as obj: - yield obj - - def clear(self): + self._put( + CacheEntry( + object=obj, + msg_id=msg_id, + context=None, + memory=None, + ) + ) + + def put_from_mem(self, mem: memoryview) -> None: """ - Clear all cached messages and identifiers. - - Resets all cache slots to None, effectively clearing the entire cache. + Reconstitute a message in mem and keep it in cache, releasing and + overwriting the existing slot in cache. + This method passes the lifecycle of the memoryview to the MessageCache + and the memoryview will be properly released by the cache with `free` + + :param mem: Source memoryview containing serialized object. + :type from_mem: memoryview + :raises UninitializedMemory: If mem buffer is not properly initialized. + """ + ctx = MessageMarshal.obj_from_mem(mem) + self._put( + CacheEntry( + object=ctx.__enter__(), + msg_id=MessageMarshal.msg_id(mem), + context=ctx, + memory=mem, + ) + ) + + def _put(self, entry: CacheEntry) -> None: + buf_idx = self._buf_idx(entry.msg_id) + self._release(buf_idx) + self._cache[buf_idx] = entry + + def _release(self, buf_idx: int) -> None: + entry = self._cache[buf_idx] + if entry is not None: + mem = entry.memory + ctx = entry.context + if ctx is not None: + ctx.__exit__(None, None, None) + del entry + self._cache[buf_idx] = None + if mem is not None: + mem.release() + + def release(self, msg_id: int) -> None: """ - self.cache_id = [None] * self.num_buffers - self.cache = [None] * self.num_buffers + Release memory for the entry associated with msg_id + :param msg_id: ID for the message to release. + :type msg_id: int + :raises CacheMiss: If requested msg_id is not in cache. 
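# A short usage sketch of the slot semantics described above: entries live at
# msg_id % num_buffers, so a newer msg_id that maps to an occupied slot releases
# the older entry. Values here are illustrative.

from ezmsg.core.messagecache import CacheMiss, MessageCache

cache = MessageCache(num_buffers=4)
cache.put_local({"sample": 1}, msg_id=0)
assert cache[0] == {"sample": 1}
assert cache.keys() == [0]

cache.put_local({"sample": 2}, msg_id=4)  # 4 % 4 == 0 -> evicts msg_id 0
try:
    cache[0]
except CacheMiss:
    pass  # expected: slot 0 now holds msg_id 4

cache.clear()  # releases every remaining entry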
+ """ + buf_idx = self._buf_idx(msg_id) + entry = self._cache[buf_idx] + if entry is None or entry.msg_id != msg_id: + raise CacheMiss + self._release(buf_idx) -# FIXME: This should be made thread-safe in the future -MessageCache: dict[UUID, Cache] = dict() + def clear(self) -> None: + """ + Release all cached objects + """ + for i in range(len(self._cache)): + self._release(i) diff --git a/src/ezmsg/core/messagechannel.py b/src/ezmsg/core/messagechannel.py new file mode 100644 index 00000000..e016f101 --- /dev/null +++ b/src/ezmsg/core/messagechannel.py @@ -0,0 +1,373 @@ +import os +import asyncio +import typing +import logging + +from uuid import UUID +from contextlib import contextmanager, suppress + +from .shm import SHMContext +from .messagemarshal import MessageMarshal +from .backpressure import Backpressure +from .messagecache import MessageCache +from .graphserver import GraphService +from .netprotocol import ( + Command, + Address, + AddressType, + read_str, + read_int, + uint64_to_bytes, + encode_str, + close_stream_writer, +) + +logger = logging.getLogger("ezmsg") + + +NotificationQueue = asyncio.Queue[typing.Tuple[UUID, int]] + + +class Channel: + """ + Channel is a "middle-man" that receives messages from a particular Publisher, + maintains the message in a MessageCache, and pushes notifications to interested + Subscribers in this process. + + Channel primarily exists to reduce redundant message serialization and telemetry. + + .. note:: + The Channel constructor should not be called directly, instead use Channel.create(...) + """ + + _SENTINEL = object() + + id: UUID + pub_id: UUID + pid: int + topic: str + + num_buffers: int + cache: MessageCache + shm: SHMContext | None + clients: dict[UUID, NotificationQueue | None] + backpressure: Backpressure + + _graph_task: asyncio.Task[None] + _pub_task: asyncio.Task[None] + _pub_writer: asyncio.StreamWriter + _graph_address: AddressType | None + _local_backpressure: Backpressure | None + + def __init__( + self, + id: UUID, + pub_id: UUID, + num_buffers: int, + shm: SHMContext | None, + graph_address: AddressType | None, + _guard = None, + ) -> None: + if _guard is not self._SENTINEL: + raise TypeError( + "Channel cannot be instantiated directly." + "Use 'await CHANNELS.register(...)' instead." + ) + + self.id = id + self.pub_id = pub_id + self.num_buffers = num_buffers + self.shm = shm + + self.cache = MessageCache(self.num_buffers) + self.backpressure = Backpressure(self.num_buffers) + self.clients = dict() + self._graph_address = graph_address + self._local_backpressure = None + + @classmethod + async def create( + cls, + pub_id: UUID, + graph_address: AddressType, + ) -> "Channel": + """ + Create a channel for a particular Publisher managed by a GraphServer at graph_address + + :param pub_id: The Publisher's UUID on the GraphServer + :type pub_id: UUID + :param graph_address: The address the GraphServer is hosted on. + :type graph_address: AddressType + :return: a configured and connected Channel for messages from the Publisher + :rtype: Channel + + .. note:: This is typically called by ChannelManager as interested Subscribers register. 
+ """ + graph_service = GraphService(graph_address) + + graph_reader, graph_writer = await graph_service.open_connection() + graph_writer.write(Command.CHANNEL.value) + graph_writer.write(encode_str(str(pub_id))) + + response = await graph_reader.read(1) + if response != Command.COMPLETE.value: + # FIXME: This will happen if the channel requested connection + # to a non-existent (or non-publisher) UUID. Ideally GraphServer + # would tell us what happened rather than drop connection + raise ValueError(f"failed to create channel {pub_id=}") + + id_str = await read_str(graph_reader) + pub_address = await Address.from_stream(graph_reader) + + reader, writer = await asyncio.open_connection(*pub_address) + writer.write(Command.CHANNEL.value) + writer.write(encode_str(id_str)) + + shm = None + shm_name = await read_str(reader) + try: + shm = await graph_service.attach_shm(shm_name) + writer.write(Command.SHM_OK.value) + except (ValueError, OSError): + shm = None + writer.write(Command.SHM_ATTACH_FAILED.value) + writer.write(uint64_to_bytes(os.getpid())) + + result = await reader.read(1) + if result != Command.COMPLETE.value: + # NOTE: The only reason this would happen is if the + # publisher's writer is closed due to a crash or shutdown + raise ValueError(f"failed to create channel {pub_id=}") + + num_buffers = await read_int(reader) + assert num_buffers > 0, "publisher reports invalid num_buffers" + + chan = cls(UUID(id_str), pub_id, num_buffers, shm, graph_address, _guard=cls._SENTINEL) + + chan._graph_task = asyncio.create_task( + chan._graph_connection(graph_reader, graph_writer), + name=f"chan-{chan.id}: _graph_connection", + ) + + chan._pub_writer = writer + chan._pub_task = asyncio.create_task( + chan._publisher_connection(reader), + name=f"chan-{chan.id}: _publisher_connection", + ) + + logger.debug(f"created channel {chan.id=} {pub_id=} {pub_address=}") + + return chan + + def close(self) -> None: + """ + Mark the Channel for shutdown and resource deallocation + """ + self._pub_task.cancel() + self._graph_task.cancel() + + async def wait_closed(self) -> None: + """ + Wait until the Channel has properly shutdown and its resources have been deallocated. + """ + with suppress(asyncio.CancelledError): + await self._pub_task + with suppress(asyncio.CancelledError): + await self._graph_task + if self.shm is not None: + await self.shm.wait_closed() + + async def _graph_connection( + self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter + ) -> None: + """ + The task that handles communication between the GraphServer and the Publisher. + """ + try: + while True: + cmd = await reader.read(1) + + if not cmd: + break + + else: + logger.warning( + f"Channel {self.id} rx unknown command from GraphServer: {cmd}" + ) + except (ConnectionResetError, BrokenPipeError): + logger.debug(f"Channel {self.id} lost connection to graph server") + + finally: + await close_stream_writer(writer) + + async def _publisher_connection(self, reader: asyncio.StreamReader) -> None: + """ + The task that handles communication between the Channel and the Publisher it receives messages from. 
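# For reference, the wire handshake performed by Channel.create above against the
# publisher's channel server (it mirrors Publisher._channel_connect later in this
# diff); each arrow is one write on the TCP stream:
#
#   channel   -> publisher : Command.CHANNEL + encode_str(str(channel_id))
#   publisher -> channel   : encode_str(shm_name)
#   channel   -> publisher : Command.SHM_OK or Command.SHM_ATTACH_FAILED
#   channel   -> publisher : uint64_to_bytes(os.getpid())
#   publisher -> channel   : Command.COMPLETE + uint64_to_bytes(num_buffers)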
+ """ + try: + while True: + msg = await reader.read(1) + + if not msg: + break + + msg_id = await read_int(reader) + buf_idx = msg_id % self.num_buffers + + if msg == Command.TX_SHM.value: + shm_name = await read_str(reader) + + if self.shm is not None and self.shm.name != shm_name: + shm_entries = self.cache.keys() + self.cache.clear() + self.shm.close() + await self.shm.wait_closed() + + try: + self.shm = await GraphService( + self._graph_address + ).attach_shm(shm_name) + except ValueError: + logger.info( + "Invalid SHM received from publisher; may be dead" + ) + raise + + for id in shm_entries: + self.cache.put_from_mem(self.shm[id % self.num_buffers]) + + assert self.shm is not None + assert MessageMarshal.msg_id(self.shm[buf_idx]) == msg_id + self.cache.put_from_mem(self.shm[buf_idx]) + + elif msg == Command.TX_TCP.value: + buf_size = await read_int(reader) + obj_bytes = await reader.readexactly(buf_size) + assert MessageMarshal.msg_id(obj_bytes) == msg_id + self.cache.put_from_mem(memoryview(obj_bytes).toreadonly()) + + else: + raise ValueError(f"unimplemented data telemetry: {msg}") + + if not self._notify_clients(msg_id): + # Nobody is listening; need to ack! + self.cache.release(msg_id) + self._acknowledge(msg_id) + + except (ConnectionResetError, BrokenPipeError, asyncio.IncompleteReadError): + logger.debug(f"connection fail: channel:{self.id} - pub:{self.pub_id}") + + finally: + self.cache.clear() + if self.shm is not None: + self.shm.close() + + await close_stream_writer(self._pub_writer) + + logger.debug(f"disconnected: channel:{self.id} -> pub:{self.pub_id}") + + def _notify_clients(self, msg_id: int) -> bool: + """notify interested clients and return true if any were notified""" + buf_idx = msg_id % self.num_buffers + for client_id, queue in self.clients.items(): + if queue is None: + continue # queue is none if this is the pub + self.backpressure.lease(client_id, buf_idx) + queue.put_nowait((self.pub_id, msg_id)) + return not self.backpressure.available(buf_idx) + + def put_local(self, msg_id: int, msg: typing.Any) -> None: + """ + Put a message DIRECTLY into cache and notify all clients. + .. note:: This command should ONLY be used by Publishers that are in the same process as this Channel. + """ + if self._local_backpressure is None: + raise ValueError( + "cannot put_local without access to publisher backpressure (is publisher in same process?)" + ) + + buf_idx = msg_id % self.num_buffers + if self._notify_clients(msg_id): + self.cache.put_local(msg, msg_id) + self._local_backpressure.lease(self.id, buf_idx) + + @contextmanager + def get( + self, msg_id: int, client_id: UUID + ) -> typing.Generator[typing.Any, None, None]: + """ + Get a message via a ContextManager + + :param msg_id: Message ID to retreive + :type msg_id: int + :param client_id: UUID of client retreiving this message for backpressure purposes + :type client_id: UUID + :raises CacheMiss: If this msg_id does not exist in the cache. 
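# To illustrate how the notification queue and get() cooperate, a minimal consumer
# loop, assuming a connected Channel instance is already in hand (consume and
# client_id are illustrative names, not part of this change):

import asyncio
import uuid


async def consume(channel) -> None:
    client_id = uuid.uuid4()
    queue: asyncio.Queue = asyncio.Queue()
    channel.register_client(client_id, queue)
    try:
        while True:
            pub_id, msg_id = await queue.get()
            with channel.get(msg_id, client_id) as msg:
                # msg is only guaranteed valid inside this block; leaving it
                # frees this client's lease and may release the cache slot.
                print(pub_id, msg)
    finally:
        channel.unregister_client(client_id)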
+ :return: A ContextManager for the message (type: Any) + :rtype: Generator[Any] + """ + + try: + yield self.cache[msg_id] + finally: + buf_idx = msg_id % self.num_buffers + self.backpressure.free(client_id, buf_idx) + if self.backpressure.buffers[buf_idx].is_empty: + self.cache.release(msg_id) + + # If pub is in same process as this channel, avoid TCP + if self._local_backpressure is not None: + self._local_backpressure.free(self.id, buf_idx) + else: + self._acknowledge(msg_id) + + def _acknowledge(self, msg_id: int) -> None: + try: + ack = Command.RX_ACK.value + uint64_to_bytes(msg_id) + self._pub_writer.write(ack) + except (BrokenPipeError, ConnectionResetError): + logger.info(f"ack fail: channel:{self.id} -> pub:{self.pub_id}") + + def register_client( + self, + client_id: UUID, + queue: NotificationQueue | None = None, + local_backpressure: Backpressure | None = None, + ) -> None: + """ + Register an interested client and provide a queue for incoming message notifications. + + :param client_id: The UUID of the subscribing client + :type client_id: UUID + :param queue: The notification queue for the subscribing client + :type queue: asyncio.Queue[tuple[UUID, int]] | None + :param local_backpressure: The backpressure object for the Publisher if it is in the same process + :type local_backpressure: Backpressure + """ + self.clients[client_id] = queue + if client_id == self.pub_id: + self._local_backpressure = local_backpressure + + def unregister_client(self, client_id: UUID) -> None: + """ + Unregister a subscribed client + + :param client_id: The UUID of the subscribing client + :type client_id: UUID + """ + queue = self.clients[client_id] + + # queue is only 'None' if this client is a local publisher + if queue is not None: + for _ in range(queue.qsize()): + pub_id, msg_id = queue.get_nowait() + if pub_id != self.pub_id: + queue.put_nowait((pub_id, msg_id)) + + self.backpressure.free(client_id) + + elif client_id == self.pub_id and self._local_backpressure is not None: + self._local_backpressure.free(self.id) + self._local_backpressure = None + + del self.clients[client_id] diff --git a/src/ezmsg/core/messagemarshal.py b/src/ezmsg/core/messagemarshal.py index ff114580..4621d642 100644 --- a/src/ezmsg/core/messagemarshal.py +++ b/src/ezmsg/core/messagemarshal.py @@ -8,20 +8,22 @@ _PREAMBLE = b"EZ" _PREAMBLE_LEN = len(_PREAMBLE) +NO_MESSAGE = _PREAMBLE + (b"\xff" * 8) + (b"\x00" * 8) class UndersizedMemory(Exception): """ Exception raised when target memory buffer is too small for serialization. - + Contains the required size needed to successfully serialize the object. """ + req_size: int def __init__(self, *args: object, req_size: int = 0) -> None: """ Initialize UndersizedMemory exception. - + :param args: Exception arguments. :param req_size: Required memory size in bytes. :type req_size: int @@ -30,13 +32,14 @@ def __init__(self, *args: object, req_size: int = 0) -> None: self.req_size = req_size -class UninitializedMemory(Exception): +class UninitializedMemory(Exception): """ Exception raised when attempting to read from uninitialized memory. - + This occurs when trying to deserialize from a memory buffer that doesn't contain valid serialized data with the expected preamble. """ + ... @@ -56,7 +59,7 @@ class Marshal: def to_mem(cls, msg_id: int, obj: Any, mem: memoryview) -> None: """ Serialize an object with message ID into a memory buffer. - + :param msg_id: Unique message identifier. :type msg_id: int :param obj: Object to serialize. 
@@ -71,43 +74,45 @@ def to_mem(cls, msg_id: int, obj: Any, mem: memoryview) -> None: if total_size >= len(mem): raise UndersizedMemory(req_size=total_size) - sidx = len(header) - mem[:sidx] = header[:] - for buf in buffers: - blen = len(buf) - mem[sidx : sidx + blen] = buf[:] - sidx += blen + cls._write(mem, header, buffers) @classmethod - def _assert_initialized(cls, mem: memoryview) -> None: - if mem[:_PREAMBLE_LEN] != _PREAMBLE: + def _write(cls, mem: memoryview, header: bytes, buffers: list[memoryview]): + sidx = len(header) + mem[:sidx] = header[:] + for buf in buffers: + blen = len(buf) + mem[sidx : sidx + blen] = buf[:] + sidx += blen + + @classmethod + def _assert_initialized(cls, raw: memoryview | bytes) -> None: + if raw[:_PREAMBLE_LEN] != _PREAMBLE: raise UninitializedMemory @classmethod - def msg_id(cls, mem: memoryview) -> int | None: + def msg_id(cls, raw: memoryview | bytes) -> int: """ - Get msg_id currently written in mem; if uninitialized, return None. - - :param mem: Memory buffer to read from. - :type mem: memoryview - :return: Message ID if memory is initialized, None otherwise. - :rtype: int | None + Get msg_id from a buffer; if uninitialized, return None. + + :param mem: buffer to read from. + :type mem: memoryview | bytes + :return: Message ID of encoded message + :rtype: int + :raises UninitializedMemory: If buffer is not initialized. """ - try: - cls._assert_initialized(mem) - return bytes_to_uint(mem[_PREAMBLE_LEN : _PREAMBLE_LEN + UINT64_SIZE]) - except UninitializedMemory: - return None + cls._assert_initialized(raw) + return bytes_to_uint(raw[_PREAMBLE_LEN : _PREAMBLE_LEN + UINT64_SIZE]) @classmethod @contextmanager def obj_from_mem(cls, mem: memoryview) -> Generator[Any, None, None]: """ Deserialize an object from a memory buffer. - + Provides a context manager for safe access to deserialized objects with automatic cleanup of memory views. - + :param mem: Memory buffer containing serialized object. :type mem: memoryview :return: Context manager yielding the deserialized object. @@ -118,6 +123,9 @@ def obj_from_mem(cls, mem: memoryview) -> Generator[Any, None, None]: sidx = _PREAMBLE_LEN + UINT64_SIZE num_buffers = bytes_to_uint(mem[sidx : sidx + UINT64_SIZE]) + if num_buffers == 0: + raise ValueError("invalid message in memory") + sidx += UINT64_SIZE buf_sizes = [0] * num_buffers for i in range(num_buffers): @@ -145,10 +153,10 @@ def serialize( ) -> Generator[tuple[int, bytes, list[memoryview]], None, None]: """ Serialize an object for network transmission. - + Creates a complete serialization package with header and buffers suitable for network transmission. - + :param msg_id: Unique message identifier. :type msg_id: int :param obj: Object to serialize. @@ -174,7 +182,7 @@ def serialize( def dump(obj: Any) -> list[memoryview]: """ Serialize an object to a list of memory buffers using pickle. - + :param obj: Object to serialize. :type obj: Any :return: List of memory views containing serialized data. @@ -189,7 +197,7 @@ def dump(obj: Any) -> list[memoryview]: def load(buffers: list[memoryview]) -> Any: """ Deserialize an object from a list of memory buffers using pickle. - + :param buffers: List of memory views containing serialized data. :type buffers: list[memoryview] :return: Deserialized object. @@ -201,16 +209,19 @@ def load(buffers: list[memoryview]) -> Any: def copy_obj(cls, from_mem: memoryview, to_mem: memoryview) -> None: """ Copy obj in from_mem (if initialized) to to_mem. - + :param from_mem: Source memory buffer containing serialized object. 
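# A self-contained round trip through the helpers above, using an ordinary
# bytearray in place of a shared-memory buffer (the 1 KiB size is arbitrary):

from ezmsg.core.messagemarshal import MessageMarshal

buf = memoryview(bytearray(1024))

MessageMarshal.to_mem(42, {"x": 1}, buf)  # preamble + msg_id + pickled buffers
assert MessageMarshal.msg_id(buf) == 42

with MessageMarshal.obj_from_mem(buf) as obj:  # rebuilt from views over buf
    assert obj == {"x": 1}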
:type from_mem: memoryview :param to_mem: Target memory buffer for copying. :type to_mem: memoryview + :raises UninitializedMemory: If from_mem buffer is not properly initialized. """ msg_id = cls.msg_id(from_mem) - if msg_id is not None: - with MessageMarshal.obj_from_mem(from_mem) as obj: - MessageMarshal.to_mem(msg_id, obj, to_mem) + with MessageMarshal.obj_from_mem(from_mem) as obj: + MessageMarshal.to_mem(msg_id, obj, to_mem) +# If some other byte-level representation is desired, you can just +# monkeypatch the module at runtime with a different Marhsal subclass +# TODO: This could also be done with environment variables MessageMarshal = Marshal diff --git a/src/ezmsg/core/netprotocol.py b/src/ezmsg/core/netprotocol.py index 3114c853..ee1d903c 100644 --- a/src/ezmsg/core/netprotocol.py +++ b/src/ezmsg/core/netprotocol.py @@ -3,6 +3,7 @@ import socket import typing import enum +import os from uuid import UUID from dataclasses import field, dataclass @@ -27,14 +28,19 @@ PUBLISHER_START_PORT_ENV = "EZMSG_PUBLISHER_PORT_START" PUBLISHER_START_PORT_DEFAULT = 25980 +GRAPHSERVER_ADDR = os.environ.get( + GRAPHSERVER_ADDR_ENV, f"{DEFAULT_HOST}:{GRAPHSERVER_PORT_DEFAULT}" +) + class Address(typing.NamedTuple): """ Network address representation with host and port. - + Provides utility methods for address parsing, serialization, and socket binding operations. """ + host: str port: int @@ -42,7 +48,7 @@ class Address(typing.NamedTuple): async def from_stream(cls, reader: asyncio.StreamReader) -> "Address": """ Read an Address from an async stream. - + :param reader: Stream reader to read address string from. :type reader: asyncio.StreamReader :return: Parsed Address instance. @@ -55,7 +61,7 @@ async def from_stream(cls, reader: asyncio.StreamReader) -> "Address": def from_string(cls, address: str) -> "Address": """ Parse an Address from a string representation. - + :param address: Address string in "host:port" format. :type address: str :return: Parsed Address instance. @@ -67,7 +73,7 @@ def from_string(cls, address: str) -> "Address": def to_stream(self, writer: asyncio.StreamWriter) -> None: """ Write this address to an async stream. - + :param writer: Stream writer to send address string to. :type writer: asyncio.StreamWriter """ @@ -76,7 +82,7 @@ def to_stream(self, writer: asyncio.StreamWriter) -> None: def bind_socket(self) -> socket.socket: """ Create and bind a socket to this address. - + :return: Socket bound to this address. :rtype: socket.socket :raises IOError: If no free ports are available. @@ -94,14 +100,13 @@ def __str__(self): class ClientInfo: """ Base information for client connections. - + Tracks client identification, communication writer, and provides synchronized access to the writer for thread-safe operations. """ + id: UUID writer: asyncio.StreamWriter - pid: int - topic: str _pending: asyncio.Event = field(default_factory=asyncio.Event, init=False) @@ -115,17 +120,16 @@ def set_sync(self) -> None: async def sync_writer(self) -> AsyncGenerator[asyncio.StreamWriter, None]: """ Get synchronized access to the writer. - + Ensures thread-safe access to the stream writer by coordinating access through an asyncio Event mechanism. - + :return: Context manager yielding the synchronized writer. 
:rtype: collections.abc.AsyncGenerator[asyncio.StreamWriter, None] """ await self._pending.wait() try: yield self.writer - await self.writer.drain() self._pending.clear() await self._pending.wait() finally: @@ -136,24 +140,35 @@ async def sync_writer(self) -> AsyncGenerator[asyncio.StreamWriter, None]: class PublisherInfo(ClientInfo): """ Publisher-specific client information. - Extends ClientInfo with the publisher's network address. """ + + topic: str address: Address @dataclass class SubscriberInfo(ClientInfo): """ - Subscriber-specific client information. + Subscriber-specific client information. """ - shm_access: bool = False + + topic: str + + +@dataclass +class ChannelInfo(ClientInfo): + """ + Channel-specific client information. + """ + + pub_id: UUID def uint64_to_bytes(i: int) -> bytes: """ Convert a 64-bit unsigned integer to bytes. - + :param i: Integer value to convert. :type i: int :return: Byte representation in little-endian format. @@ -165,7 +180,7 @@ def uint64_to_bytes(i: int) -> bytes: def bytes_to_uint(b: bytes) -> int: """ Convert bytes to a 64-bit unsigned integer. - + :param b: Byte data to convert. :type b: bytes :return: Integer value decoded from little-endian bytes. @@ -177,7 +192,7 @@ def bytes_to_uint(b: bytes) -> int: def encode_str(string: str) -> bytes: """ Encode a string with length prefix for network transmission. - + :param string: String to encode. :type string: str :return: Length-prefixed UTF-8 encoded bytes. @@ -191,7 +206,7 @@ def encode_str(string: str) -> bytes: async def read_int(reader: asyncio.StreamReader) -> int: """ Read a 64-bit unsigned integer from an async stream. - + :param reader: Stream reader to read from. :type reader: asyncio.StreamReader :return: Integer value read from stream. @@ -231,16 +246,16 @@ async def close_server(server: Server): class Command(enum.Enum): """ Enumeration of protocol commands for ezmsg network communication. - + Defines all command types used in the ezmsg protocol for graph management, publisher-subscriber communication, and shared memory operations. """ - + @staticmethod def _generate_next_value_(name, start, count, last_values) -> bytes: """ Generate byte values for enum members. - + :param name: Name of the enum member. :type name: str :param start: Starting value (unused). @@ -278,6 +293,10 @@ def _generate_next_value_(name, start, count, last_values) -> bytes: SHUTDOWN = enum.auto() + CHANNEL = enum.auto() + SHM_OK = enum.auto() + SHM_ATTACH_FAILED = enum.auto() + def create_socket( host: str | None = None, @@ -288,10 +307,10 @@ def create_socket( ) -> socket.socket: """ Create a socket bound to an available port. - + Attempts to bind to the specified port, or searches for an available port within the given range if no specific port is provided. - + :param host: Host address to bind to (defaults to DEFAULT_HOST). :type host: str | None :param port: Specific port to bind to (if None, searches for available port). @@ -306,24 +325,31 @@ def create_socket( :rtype: socket.socket :raises IOError: If no available ports can be found in the specified range. 
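# The helpers above are all fixed-width or length-prefixed, so protocol frames are
# built by simple concatenation; a small illustration with arbitrary values:

from ezmsg.core.netprotocol import (
    Address,
    Command,
    bytes_to_uint,
    encode_str,
    uint64_to_bytes,
)

frame = Command.PUBLISH.value + encode_str("/demo")  # command byte + topic string
assert bytes_to_uint(uint64_to_bytes(123)) == 123    # uint64 round trip
addr = Address.from_string("127.0.0.1:25978")        # "host:port" parsing
print(frame, addr)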
""" - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) if host is None: host = DEFAULT_HOST if port is not None: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # set REUSEADDR if a port is explicitly requested; leads to quick server restart on same port + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) sock.bind((host, port)) return sock port = start_port while port <= max_port: if port not in ignore_ports: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) try: + # setting REUSEADDR during portscan can lead to race conditions during bind on Linux + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) sock.bind((host, port)) return sock except OSError: + sock.close() pass + port += 1 raise IOError("Failed to bind socket; no free ports") diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index 33155b0e..e12b71e1 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -5,15 +5,18 @@ from uuid import UUID from contextlib import suppress +from dataclasses import dataclass from .backpressure import Backpressure from .shm import SHMContext from .graphserver import GraphService -from .messagecache import MessageCache, Cache -from .messagemarshal import MessageMarshal, UndersizedMemory +from .channelmanager import CHANNELS +from .messagechannel import Channel +from .messagemarshal import MessageMarshal, UninitializedMemory from .netprotocol import ( Address, + AddressType, uint64_to_bytes, read_int, read_str, @@ -21,7 +24,7 @@ close_stream_writer, close_server, Command, - SubscriberInfo, + ChannelInfo, create_socket, DEFAULT_SHM_SIZE, PUBLISHER_START_PORT_ENV, @@ -36,15 +39,25 @@ BACKPRESSURE_REFRACTORY = 5.0 # sec +# Publisher needs a bit more information about connected channels +@dataclass +class PubChannelInfo(ChannelInfo): + pid: int + shm_ok: bool = False + + class Publisher: """ A publisher client for broadcasting messages to subscribers. - + Publisher manages shared memory allocation, connection handling with subscribers, backpressure control, and supports both shared memory and TCP transport methods. Messages are broadcast to all connected subscribers with automatic cleanup and resource management. """ + + _SENTINEL = object() + id: UUID pid: int topic: str @@ -52,8 +65,9 @@ class Publisher: _initialized: asyncio.Event _graph_task: "asyncio.Task[None]" _connection_task: "asyncio.Task[None]" - _subscribers: dict[UUID, SubscriberInfo] - _subscriber_tasks: dict[UUID, "asyncio.Task[None]"] + _channels: dict[UUID, PubChannelInfo] + _channel_tasks: dict[UUID, "asyncio.Task[None]"] + _local_channel: Channel _address: Address _backpressure: Backpressure _num_buffers: int @@ -63,13 +77,13 @@ class Publisher: _force_tcp: bool _last_backpressure_event: float - _graph_service: GraphService + _graph_address: AddressType | None @staticmethod def client_type() -> bytes: """ Get the client type identifier for publishers. - + :return: Command byte identifying this as a publisher client. 
:rtype: bytes """ @@ -79,15 +93,17 @@ def client_type() -> bytes: async def create( cls, topic: str, - graph_service: GraphService, + graph_address: AddressType | None = None, host: str | None = None, port: int | None = None, buf_size: int = DEFAULT_SHM_SIZE, - **kwargs, + num_buffers: int = 32, + start_paused: bool = False, + force_tcp: bool = False, ) -> "Publisher": """ Create a new Publisher instance and register it with the graph server. - + :param topic: The topic this publisher will broadcast to. :type topic: str :param graph_service: Service for graph server communication. @@ -104,53 +120,82 @@ async def create( :return: Initialized and registered Publisher instance. :rtype: Publisher """ + graph_service = GraphService(graph_address) reader, writer = await graph_service.open_connection() + shm = await graph_service.create_shm(num_buffers, buf_size) + writer.write(Command.PUBLISH.value) - id = UUID(await read_str(reader)) - pub = cls(id, topic, graph_service, **kwargs) - writer.write(uint64_to_bytes(pub.pid)) - writer.write(encode_str(pub.topic)) - pub._shm = await graph_service.create_shm(pub._num_buffers, buf_size) + writer.write(encode_str(topic)) + + pub_id = UUID(await read_str(reader)) + pub = cls( + id=pub_id, + topic=topic, + shm=shm, + graph_address=graph_address, + num_buffers=num_buffers, + start_paused=start_paused, + force_tcp=force_tcp, + _guard=cls._SENTINEL, + ) start_port = int( os.getenv(PUBLISHER_START_PORT_ENV, PUBLISHER_START_PORT_DEFAULT) ) sock = create_socket(host, port, start_port=start_port) - server = await asyncio.start_server(pub._on_connection, sock=sock) - pub._address = Address(*sock.getsockname()) - pub._address.to_stream(writer) - pub._graph_task = asyncio.create_task(pub._graph_connection(reader, writer)) - - async def serve() -> None: - try: - await server.serve_forever() - except asyncio.CancelledError: # FIXME: Poor form? 
- logger.debug("pubclient serve is Cancelled...") - finally: - await close_server(server) - - pub._connection_task = asyncio.create_task(serve(), name=f"pub_{str(id)}") - - def on_done(_: asyncio.Future) -> None: - logger.debug("Closing pub server task.") - - pub._connection_task.add_done_callback(on_done) - MessageCache[id] = Cache(pub._num_buffers) - await pub._initialized.wait() + server = await asyncio.start_server(pub._channel_connect, sock=sock) + pub._connection_task = asyncio.create_task( + pub._serve_channels(server), name=f"pub-{pub.id}: {pub.topic}" + ) + + # Notify GraphServer that our server is up + channel_server_address = Address(*sock.getsockname()) + channel_server_address.to_stream(writer) + result = await reader.read(1) # channels connect + if result != Command.COMPLETE.value: + logger.warning(f"Could not create publisher {topic=}") + + # Pass off graph connection keep-alive to publisher task + pub._graph_task = asyncio.create_task( + pub._graph_connection(reader, writer), + name=f"pub-{pub.id}: _graph_connection", + ) + + pub._local_channel = await CHANNELS.register_local_pub( + pub_id=pub.id, + local_backpressure=pub._backpressure, + graph_address=pub._graph_address, + ) + + logger.debug(f"created pub {pub.id=} {topic=} {channel_server_address=}") + return pub + async def _serve_channels(self, server: asyncio.Server) -> None: + try: + await server.serve_forever() + except asyncio.CancelledError: + logger.debug(f"{self.log_name} cancelled") + finally: + await close_server(server) + await CHANNELS.unregister(self.id, self.id, self._graph_address) + logger.debug(f"{self.log_name} done") + def __init__( self, id: UUID, topic: str, - graph_service: GraphService, + shm: SHMContext, + graph_address: AddressType | None = None, num_buffers: int = 32, start_paused: bool = False, force_tcp: bool = False, + _guard = None ) -> None: """ Initialize a Publisher instance. - + DO NOT USE this constructor to make a Publisher; use `create` instead + :param id: Unique identifier for this publisher. :type id: UUID :param topic: The topic this publisher broadcasts to. @@ -164,41 +209,49 @@ def __init__( :param force_tcp: Whether to force TCP transport instead of shared memory. :type force_tcp: bool """ + if _guard is not self._SENTINEL: + raise TypeError( + "Publisher cannot be instantiated directly." + "Use 'await Publisher.create(...)' instead." + ) + self.id = id self.pid = os.getpid() self.topic = topic - + self._shm = shm self._msg_id = 0 - self._subscribers = dict() - self._subscriber_tasks = dict() + self._channels = dict() + self._channel_tasks = dict() self._running = asyncio.Event() if not start_paused: self._running.set() self._num_buffers = num_buffers self._backpressure = Backpressure(num_buffers) self._force_tcp = force_tcp - self._initialized = asyncio.Event() self._last_backpressure_event = -1 + self._graph_address = graph_address - self._graph_service = graph_service + @property + def log_name(self) -> str: + return f"pub_{self.topic}{str(self.id)}" def close(self) -> None: """ Close the publisher and cancel all associated tasks. - + Cancels graph connection, shared memory, connection server, and all subscriber handling tasks. """ self._graph_task.cancel() self._shm.close() self._connection_task.cancel() - for task in self._subscriber_tasks.values(): + for task in self._channel_tasks.values(): task.cancel() async def wait_closed(self) -> None: """ Wait for all publisher resources to be fully closed. 
- + Waits for shared memory cleanup, graph connection termination, connection server shutdown, and all subscriber tasks to complete. """ @@ -207,7 +260,7 @@ async def wait_closed(self) -> None: await self._graph_task with suppress(asyncio.CancelledError): await self._connection_task - for task in self._subscriber_tasks.values(): + for task in self._channel_tasks.values(): with suppress(asyncio.CancelledError): await task @@ -216,10 +269,10 @@ async def _graph_connection( ) -> None: """ Handle communication with the graph server. - + Processes commands from the graph server including COMPLETE, PAUSE, RESUME, and SYNC operations. - + :param reader: Stream reader for receiving commands from graph server. :type reader: asyncio.StreamReader :param writer: Stream writer for responding to graph server. @@ -231,9 +284,6 @@ async def _graph_connection( if not cmd: break - elif cmd == Command.COMPLETE.value: - self._initialized.set() - elif cmd == Command.PAUSE.value: self._running.clear() @@ -257,57 +307,56 @@ async def _graph_connection( finally: await close_stream_writer(writer) - async def _on_connection( + async def _channel_connect( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter ) -> None: """ - Handle new subscriber connections. - - Exchanges identification information with connecting subscribers - and sets up subscriber handling tasks. - + Handle new channel connections. + + Exchanges identification information with connecting channels + and sets up channel handling tasks. + :param reader: Stream reader for receiving subscriber info. :type reader: asyncio.StreamReader :param writer: Stream writer for sending publisher info. :type writer: asyncio.StreamWriter """ - id_str = await read_str(reader) - id = UUID(id_str) - pid = await read_int(reader) - topic = await read_str(reader) - - # Subscriber determines if they have SHM access - writer.write(encode_str(self._shm.name)) - shm_access = bool(await read_int(reader)) - - writer.write( - encode_str(str(self.id)) - + uint64_to_bytes(self.pid) - + encode_str(self.topic) - + uint64_to_bytes(self._num_buffers) - ) + cmd = await reader.read(1) - info = SubscriberInfo(id, writer, pid, topic, shm_access) - coro = self._handle_subscriber(info, reader) - self._subscriber_tasks[id] = asyncio.create_task(coro) + if len(cmd) == 0: + return - await writer.drain() + if cmd == Command.CHANNEL.value: + channel_id_str = await read_str(reader) + channel_id = UUID(channel_id_str) + writer.write(encode_str(self._shm.name)) + shm_ok = await reader.read(1) == Command.SHM_OK.value + pid = await read_int(reader) + info = PubChannelInfo(channel_id, writer, self.id, pid, shm_ok) + coro = self._handle_channel(info, reader) + self._channel_tasks[channel_id] = asyncio.create_task(coro) + writer.write(Command.COMPLETE.value + uint64_to_bytes(self._num_buffers)) + await writer.drain() - async def _handle_subscriber( - self, info: SubscriberInfo, reader: asyncio.StreamReader + else: + raise ValueError(f"Publisher {self.id}: unexpected command {cmd=}") + + + async def _handle_channel( + self, info: PubChannelInfo, reader: asyncio.StreamReader ) -> None: """ - Handle communication with a specific subscriber. - - Processes acknowledgments from subscribers and manages backpressure - control based on subscriber feedback. - - :param info: Information about the subscriber connection. - :type info: SubscriberInfo - :param reader: Stream reader for receiving subscriber messages. + Handle communication with a specific channel. 
+ + Processes acknowledgments from channels and manages backpressure + control based on channel feedback. + + :param info: Information about the channel connection. + :type info: PubChannelInfo + :param reader: Stream reader for receiving channel messages. :type reader: asyncio.StreamReader """ - self._subscribers[info.id] = info + self._channels[info.id] = info try: while True: @@ -321,17 +370,17 @@ async def _handle_subscriber( self._backpressure.free(info.id, msg_id % self._num_buffers) except (ConnectionResetError, BrokenPipeError): - logger.debug(f"Publisher {self.id}: Subscriber {id} connection fail") + logger.debug(f"Publisher {self.id}: Channel {info.id} connection fail") finally: self._backpressure.free(info.id) - await close_stream_writer(self._subscribers[info.id].writer) - del self._subscribers[info.id] + await close_stream_writer(self._channels[info.id].writer) + del self._channels[info.id] async def sync(self) -> None: """ Pause and drain backpressure. - + Temporarily pauses the publisher and waits for all pending messages to be acknowledged by subscribers. """ @@ -342,7 +391,7 @@ async def sync(self) -> None: def running(self) -> bool: """ Check if the publisher is currently running. - + :return: True if publisher is running and accepting broadcasts. :rtype: bool """ @@ -351,7 +400,7 @@ def running(self) -> bool: def pause(self) -> None: """ Pause the publisher to stop broadcasting messages. - + Messages sent to broadcast() will block until resumed. """ self._running.clear() @@ -359,7 +408,7 @@ def pause(self) -> None: def resume(self) -> None: """ Resume the publisher to allow broadcasting messages. - + Unblocks any pending broadcast() calls. """ self._running.set() @@ -367,10 +416,10 @@ def resume(self) -> None: async def broadcast(self, obj: Any) -> None: """ Broadcast a message to all connected subscribers. - + Handles message serialization, shared memory management, transport selection (local/SHM/TCP), and backpressure control automatically. - + :param obj: The object/message to broadcast to subscribers. 
:type obj: Any """ @@ -386,56 +435,76 @@ async def broadcast(self, obj: Any) -> None: self._last_backpressure_event = time.time() await self._backpressure.wait(buf_idx) - MessageCache[self.id].put(self._msg_id, obj) - - for sub in list(self._subscribers.values()): - if not self._force_tcp and sub.shm_access: - if sub.pid == self.pid: - sub.writer.write(Command.TX_LOCAL.value + msg_id_bytes) - - else: - try: - # Push cache to shm (if not already there) - MessageCache[self.id].push(self._msg_id, self._shm) - - except UndersizedMemory as e: - new_shm = await self._graph_service.create_shm( - self._num_buffers, e.req_size * 2 + # Get local channel and put variable there for local tx + self._local_channel.put_local(self._msg_id, obj) + + if self._force_tcp or any( + ch.pid != self.pid or not ch.shm_ok for ch in self._channels.values() + ): + with MessageMarshal.serialize(self._msg_id, obj) as ( + total_size, + header, + buffers, + ): + total_size_bytes = uint64_to_bytes(total_size) + + if not self._force_tcp and any( + ch.pid != self.pid and ch.shm_ok for ch in self._channels.values() + ): + if self._shm.buf_size < total_size: + new_shm = await GraphService(self._graph_address).create_shm( + self._num_buffers, total_size * 2 ) for i in range(self._num_buffers): - with self._shm.buffer(i, readonly=True) as from_buf: - with new_shm.buffer(i) as to_buf: - MessageMarshal.copy_obj(from_buf, to_buf) + try: + with self._shm.buffer(i, readonly=True) as from_buf: + with new_shm.buffer(i) as to_buf: + MessageMarshal.copy_obj(from_buf, to_buf) + except UninitializedMemory: + pass self._shm.close() + await self._shm.wait_closed() self._shm = new_shm - MessageCache[self.id].push(self._msg_id, self._shm) - - sub.writer.write(Command.TX_SHM.value) - sub.writer.write(msg_id_bytes) - sub.writer.write(encode_str(self._shm.name)) - - else: - with MessageMarshal.serialize(self._msg_id, obj) as ser_obj: - total_size, header, buffers = ser_obj - total_size_bytes = uint64_to_bytes(total_size) - - sub.writer.write(Command.TX_TCP.value) - sub.writer.write(msg_id_bytes) - sub.writer.write(total_size_bytes) - sub.writer.write(header) - for buffer in buffers: - sub.writer.write(buffer) - - try: - await sub.writer.drain() - self._backpressure.lease(sub.id, buf_idx) - - except (ConnectionResetError, BrokenPipeError): - logger.debug( - f"Publisher {self.id}: Subscriber {sub.id} connection fail" - ) - continue + + with self._shm.buffer(buf_idx) as mem: + MessageMarshal._write(mem, header, buffers) + + for channel in self._channels.values(): + msg: bytes = b"" + + if self.pid == channel.pid and channel.shm_ok: + continue # Local transmission handled by channel.put + + elif ( + (not self._force_tcp) + and self.pid != channel.pid + and channel.shm_ok + ): + msg = ( + Command.TX_SHM.value + + msg_id_bytes + + encode_str(self._shm.name) + ) + + else: + msg = ( + Command.TX_TCP.value + + msg_id_bytes + + total_size_bytes + + header + + b"".join([buffer for buffer in buffers]) + ) + + try: + channel.writer.write(msg) + await channel.writer.drain() + self._backpressure.lease(channel.id, buf_idx) + + except (ConnectionResetError, BrokenPipeError): + logger.debug( + f"Publisher {self.id}: Channel {channel.id} connection fail" + ) self._msg_id += 1 diff --git a/src/ezmsg/core/server.py b/src/ezmsg/core/server.py deleted file mode 100644 index 2b10a736..00000000 --- a/src/ezmsg/core/server.py +++ /dev/null @@ -1,256 +0,0 @@ -import os -from collections.abc import Callable -import asyncio -import logging -import socket -import typing - 
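# The per-channel transport choice implemented in Publisher.broadcast() above
# reduces to roughly the following (an illustrative paraphrase, not the real code
# path, which also skips serialization entirely when every channel is local):


def pick_transport(pub_pid: int, force_tcp: bool, channel) -> str:
    if pub_pid == channel.pid and channel.shm_ok:
        return "local"  # same process: Channel.put_local already holds the object
    if not force_tcp and channel.shm_ok:
        return "shm"    # other process with SHM attached: send shm name + msg_id
    return "tcp"        # fallback: stream the serialized bytes over the socket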
-from contextlib import suppress -from threading import Thread, Event - -from .netprotocol import ( - Address, - AddressType, - close_server, - close_stream_writer, - create_socket, - SERVER_PORT_START_ENV, - SERVER_PORT_START_DEFAULT, -) - -logger = logging.getLogger("ezmsg") - - -class ThreadedAsyncServer(Thread): - """ - An asyncio server that runs in a dedicated loop in a separate thread. - - This class provides a foundation for running asyncio-based servers in their own - threads, allowing for concurrent operation with other parts of the application. - - :obj:`GraphServer` inherits from this class to implement specific server functionality. - """ - - _server_up: Event - _shutdown: Event - - _sock: socket.socket - _loop: asyncio.AbstractEventLoop - - def __init__(self): - """ - Initialize the threaded async server. - """ - super().__init__(daemon=True) - self._server_up = Event() - self._shutdown = Event() - - @property - def address(self) -> Address: - return Address(*self._sock.getsockname()) - - def start(self, address: AddressType | None = None) -> None: - if address is not None: - self._sock = create_socket(*address) - else: - start_port = int( - os.environ.get(SERVER_PORT_START_ENV, SERVER_PORT_START_DEFAULT) - ) - self._sock = create_socket(start_port=start_port) - - self._loop = asyncio.new_event_loop() - super().start() - self._server_up.wait() - - def stop(self) -> None: - self._shutdown.set() - self.join() - - def run(self) -> None: - """ - Run the async server in its own thread. - - This method starts a new asyncio event loop and runs the server's main - execution logic within that loop. - """ - try: - asyncio.set_event_loop(self._loop) - with suppress(asyncio.CancelledError): - self._loop.run_until_complete(self.amain()) - finally: - self._loop.stop() - self._loop.close() - - async def amain(self) -> None: - """ - Main asynchronous execution method for the server. - - This abstract method should be implemented by subclasses to define - the specific server behavior and async operations. - """ - await self.setup() - - server = await asyncio.start_server(self.api, sock=self._sock) - - async def monitor_shutdown() -> None: - await self._loop.run_in_executor(None, self._shutdown.wait) - await close_server(server) - - monitor_task = self._loop.create_task(monitor_shutdown()) - - self._server_up.set() - - try: - await server.serve_forever() - - finally: - await self.shutdown() - monitor_task.cancel() - with suppress(asyncio.CancelledError): - await monitor_task - - async def setup(self) -> None: - """ - Setup method called before server starts serving. - - This method can be overridden by subclasses to perform any - initialization needed before the server begins accepting connections. - """ - ... - - async def shutdown(self) -> None: - """ - Shutdown method called when server is stopping. - - This method can be overridden by subclasses to perform any - cleanup needed when the server is shutting down. - """ - ... - - async def api( - self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter - ) -> None: - """ - API handler for client connections. - - This abstract method must be implemented by subclasses to handle - the server's communication protocol with connected clients. - - :param reader: Stream reader for receiving data from clients. - :type reader: asyncio.StreamReader - :param writer: Stream writer for sending data to clients. - :type writer: asyncio.StreamWriter - :raises NotImplementedError: Must be implemented by subclasses. 
- """ - raise NotImplementedError - - -T = typing.TypeVar("T", bound=ThreadedAsyncServer) - - -class ServiceManager(typing.Generic[T]): - """ - Manages the lifecycle and connection of threaded async servers. - - This generic class provides utilities for ensuring servers are running, - managing connections, and handling server creation with proper addressing. - - :type T: A ThreadedAsyncServer subclass type. - """ - _address: Address | None = None - _factory: Callable[[], T] - - ADDR_ENV: str - PORT_DEFAULT: int - - def __init__( - self, - factory: Callable[[], T], - address: AddressType | None = None, - ) -> None: - """ - Initialize the service manager. - - :param factory: Factory function that creates server instances. - :type factory: collections.abc.Callable[[], T] - :param address: Optional address tuple (host, port) for the server. - :type address: AddressType | None - """ - self._factory = factory - if address is not None: - self._address = Address(*address) - - async def ensure(self) -> T | None: - """ - Ensure a server is running and accessible. - - Attempts to connect to an existing server. If connection fails and - we should ensure a server exists, creates a new server instance. - - :return: The server instance if one was created, None if using existing. - :rtype: T | None - :raises OSError: If connection fails and no server should be created. - """ - server = None - ensure_server = False - if self._address is None: - ensure_server = self.ADDR_ENV not in os.environ - - try: - reader, writer = await self.open_connection() - await close_stream_writer(writer) - - except OSError as ref_e: - if not ensure_server: - raise ref_e - - server = self.create_server() - - return server - - @property - def address(self) -> Address: - """ - Get the server address. - - :return: The configured address or default address if none set. - :rtype: Address - """ - return self._address if self._address is not None else self.default_address() - - @classmethod - def default_address(cls) -> Address: - """ - Get the default server address from environment or class constants. - - :return: Address parsed from environment variable or default host:port. - :rtype: Address - """ - address_str = os.environ.get(cls.ADDR_ENV, f"127.0.0.1:{cls.PORT_DEFAULT}") - return Address.from_string(address_str) - - def create_server(self) -> T: - """ - Create and start a new server instance. - - :return: The newly created and started server instance. - :rtype: T - """ - server = self._factory() - server.start(self._address) - self._address = server.address - return server - - async def open_connection( - self, - ) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]: - """ - Open a connection to the managed server. - - :return: Tuple of (reader, writer) for communicating with the server. - :rtype: tuple[asyncio.StreamReader, asyncio.StreamWriter] - :raises OSError: If connection cannot be established. - """ - return await asyncio.open_connection( - *(self.default_address() if self._address is None else self._address) - ) diff --git a/src/ezmsg/core/settings.py b/src/ezmsg/core/settings.py index 1ffad2a0..d8766387 100644 --- a/src/ezmsg/core/settings.py +++ b/src/ezmsg/core/settings.py @@ -18,10 +18,11 @@ class SettingsMeta(ABCMeta): """ Metaclass that automatically applies dataclass decorator to Settings classes. - + This metaclass ensures all Settings subclasses are automatically converted to frozen dataclasses, providing immutability and proper initialization. 
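# Because the metaclass applies the frozen dataclass transform automatically, a
# Settings subclass only needs field declarations; an illustrative example:

import ezmsg.core as ez


class DemoSettings(ez.Settings):
    rate: float
    label: str = "demo"


settings = DemoSettings(rate=10.0)
# settings.rate = 5.0  # would raise: instances are frozen (immutable)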
""" + def __new__( cls, name: str, @@ -31,7 +32,7 @@ def __new__( ) -> type["Settings"]: """ Create a new Settings class with dataclass transformation. - + :param name: Name of the class being created. :type name: str :param bases: Base classes for the new class. diff --git a/src/ezmsg/core/shm.py b/src/ezmsg/core/shm.py index 7c0d33eb..279aeb9c 100644 --- a/src/ezmsg/core/shm.py +++ b/src/ezmsg/core/shm.py @@ -1,7 +1,6 @@ import asyncio from collections.abc import Generator import logging -import typing from dataclasses import dataclass, field from contextlib import contextmanager, suppress @@ -11,6 +10,7 @@ from .netprotocol import ( close_stream_writer, bytes_to_uint, + uint64_to_bytes, ) logger = logging.getLogger("ezmsg") @@ -28,10 +28,10 @@ def _ignore_shm(name, rtype): def _untracked_shm() -> Generator[None, None, None]: """ Disable SHM tracking within context - https://bugs.python.org/issue38119. - + This context manager temporarily disables shared memory tracking to work around a Python bug where shared memory segments are not properly cleaned up. - + :return: Context manager generator. :rtype: collections.abc.Generator[None, None, None] """ @@ -61,18 +61,17 @@ class SHMContext: This format repeats itself for every buffer in the SharedMemory block. """ - _shm: SharedMemory - _data_block_segs: list[slice] - num_buffers: int buf_size: int - monitor: asyncio.Future + _shm: SharedMemory + _data_block_segs: list[slice] + _graph_task: asyncio.Task[None] def __init__(self, name: str) -> None: """ Initialize SHMContext by connecting to an existing shared memory segment. - + :param name: The name of the shared memory segment to connect to. :type name: str :raises BufferError: If shared memory segment cannot be accessed. @@ -95,12 +94,12 @@ def __init__(self, name: str) -> None: ] @classmethod - def _create( + def attach( cls, shm_name: str, reader: asyncio.StreamReader, writer: asyncio.StreamWriter ) -> "SHMContext": """ Create a new SHMContext with connection monitoring. - + :param shm_name: Name of the shared memory segment. :type shm_name: str :param reader: Stream reader for connection monitoring. @@ -111,22 +110,36 @@ def _create( :rtype: SHMContext """ context = cls(shm_name) + context._graph_task = asyncio.create_task( + context._graph_connection(reader, writer), name=f"{context.name}_monitor" + ) + return context - async def monitor() -> None: - try: - await reader.read() - logger.debug("Read from SHMContext monitor reader") - except asyncio.CancelledError: - pass - finally: - await close_stream_writer(writer) - - def close(_: asyncio.Future) -> None: - context.close() + async def _graph_connection( + self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter + ) -> None: + try: + await reader.read() + logger.debug(f"SHMContext {self.name} GraphServer disconnected; closing.") + except (ConnectionResetError, BrokenPipeError) as e: + logger.debug(f"SHMContext {self.name} GraphServer {type(e)}") + finally: + self._shm.close() + await close_stream_writer(writer) + + def __getitem__(self, idx: int) -> memoryview: + """ + Get a memory view of a specific buffer in the shared memory segment. - context.monitor = asyncio.create_task(monitor(), name=f"{shm_name}_monitor") - context.monitor.add_done_callback(close) - return context + :param idx: Index of the buffer to access. + :type idx: int + :return: A memoryview of the buffer. + :rtype: memoryview + :raises BufferError: If shared memory is no longer accessible. 
+ """ + if self._shm.buf is None: + raise BufferError(f"cannot access {self.name}: server disconnected") + return self._shm.buf[self._data_block_segs[idx]] @contextmanager def buffer( @@ -134,7 +147,7 @@ def buffer( ) -> Generator[memoryview, None, None]: """ Get a memory view of a specific buffer in the shared memory segment. - + :param idx: Index of the buffer to access. :type idx: int :param readonly: Whether to provide read-only access to the buffer. @@ -143,58 +156,36 @@ def buffer( :rtype: collections.abc.Generator[memoryview, None, None] :raises BufferError: If shared memory is no longer accessible. """ - if self._shm.buf is None: - raise BufferError(f"cannot access {self._shm.name}: server disconnected") - - with self._shm.buf[self._data_block_segs[idx]] as mem: - if readonly: - ro_mem = mem.toreadonly() - yield ro_mem - ro_mem.release() - else: - yield mem + mem = self[idx].toreadonly() if readonly else self[idx] + try: + yield mem + finally: + mem.release() def close(self) -> None: """ Close the shared memory context and cancel monitoring. - + This initiates an asynchronous close operation and cancels the connection monitor task. """ - asyncio.create_task(self.close_shm(), name=f"Close {self._shm.name}") - self.monitor.cancel() - - async def close_shm(self) -> None: - """ - Asynchronously close the shared memory segment. - - Retries closing if BufferError is encountered, as the segment - may still be in use by other processes. - """ - while True: - try: - self._shm.close() - logger.debug("Closed SHM segment.") - return - except BufferError: - logger.debug("BufferError caught... Sleeping.") - await asyncio.sleep(1) + self._graph_task.cancel() async def wait_closed(self) -> None: """ Wait for the shared memory context to be fully closed. - + This method waits for the monitoring task to complete, indicating that the connection has been properly terminated. """ with suppress(asyncio.CancelledError): - await self.monitor + await self._graph_task @property def name(self) -> str: """ Get the name of the shared memory segment. - + :return: The shared memory segment name. :rtype: str """ @@ -204,7 +195,7 @@ def name(self) -> str: def size(self) -> int: """ Get the usable size of each buffer (excluding header). - + :return: Buffer size minus 16-byte header. :rtype: int """ @@ -215,22 +206,32 @@ def size(self) -> int: class SHMInfo: """ Information about a shared memory segment and its active leases. - + Tracks the SharedMemory object and manages client connection leases. When all leases are released, the shared memory is automatically cleaned up. """ + shm: SharedMemory leases: set["asyncio.Task[None]"] = field(default_factory=set) + @classmethod + def create(cls, num_buffers: int, buf_size: int) -> "SHMInfo": + buf_size += 16 * num_buffers # Repeated header info makes this a bit bigger + shm = SharedMemory(size=num_buffers * buf_size, create=True) + shm.buf[:] = b"0" * len(shm.buf) # Guarantee zeros + shm.buf[0:8] = uint64_to_bytes(num_buffers) + shm.buf[8:16] = uint64_to_bytes(buf_size) + return cls(shm) + def lease( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter ) -> "asyncio.Task[None]": """ Create a lease for this shared memory segment. - + The lease monitors the client connection and automatically releases the shared memory when the client disconnects. - + :param reader: Stream reader to monitor for client disconnection. :type reader: asyncio.StreamReader :param writer: Stream writer for connection cleanup. 
@@ -238,9 +239,12 @@ def lease( :return: Task representing the active lease. :rtype: asyncio.Task[None] """ + async def _wait_for_eof() -> None: try: await reader.read() + except (ConnectionResetError, BrokenPipeError): + pass finally: await close_stream_writer(writer) @@ -251,7 +255,7 @@ async def _wait_for_eof() -> None: def _release(self, task: "asyncio.Task[None]"): self.leases.discard(task) - logger.debug(f"discarded lease from {self.shm.name}") + logger.debug(f"discarded lease from {self.shm.name}; {len(self.leases)} left") if len(self.leases) == 0: logger.debug(f"unlinking {self.shm.name}") self.shm.close() diff --git a/src/ezmsg/core/state.py b/src/ezmsg/core/state.py index 385f3520..c298fd7d 100644 --- a/src/ezmsg/core/state.py +++ b/src/ezmsg/core/state.py @@ -7,10 +7,11 @@ class StateMeta(ABCMeta): """ Metaclass that automatically applies dataclass decorator to State classes. - + This metaclass ensures all State subclasses are automatically converted to mutable dataclasses with hash support but no automatic initialization. """ + def __new__( cls, name: str, @@ -20,7 +21,7 @@ def __new__( ) -> type["State"]: """ Create a new State class with dataclass transformation. - + :param name: Name of the class being created. :type name: str :param bases: Base classes for the new class. diff --git a/src/ezmsg/core/stream.py b/src/ezmsg/core/stream.py index 99593f1d..2af20f3c 100644 --- a/src/ezmsg/core/stream.py +++ b/src/ezmsg/core/stream.py @@ -7,7 +7,7 @@ class Stream(Addressable): """ Base class for all streams in the ezmsg framework. - + Streams define the communication channels between components, carrying messages of a specific type through the system. @@ -23,13 +23,13 @@ def __init__(self, msg_type: Any): def __repr__(self) -> str: _addr = self.address if self._location is not None else "unlocated" - return f"Stream:{_addr}[{self.msg_type}]" + return f"Stream:{_addr}[{self.msg_type.__name__}]" class InputStream(Stream): """ Can be added to any Component as a member variable. Methods may subscribe to it. - + InputStream represents a channel that receives messages from other components. Units can subscribe to InputStreams to process incoming messages. @@ -44,7 +44,7 @@ def __repr__(self) -> str: class OutputStream(Stream): """ Can be added to any Component as a member variable. Methods may publish to it. - + OutputStream represents a channel that sends messages to other components. Units can publish to OutputStreams to send messages through the system. diff --git a/src/ezmsg/core/subclient.py b/src/ezmsg/core/subclient.py index b9697a91..fee285cc 100644 --- a/src/ezmsg/core/subclient.py +++ b/src/ezmsg/core/subclient.py @@ -1,6 +1,4 @@ -import os import asyncio -from collections.abc import AsyncGenerator import logging import typing @@ -9,21 +7,15 @@ from copy import deepcopy from .graphserver import GraphService -from .shm import SHMContext -from .messagecache import MessageCache, Cache -from .messagemarshal import MessageMarshal +from .channelmanager import CHANNELS +from .messagechannel import NotificationQueue, Channel from .netprotocol import ( - Address, - UINT64_SIZE, - uint64_to_bytes, - bytes_to_uint, - read_int, + AddressType, read_str, encode_str, close_stream_writer, Command, - PublisherInfo, ) @@ -33,31 +25,39 @@ class Subscriber: """ A subscriber client for receiving messages from publishers. 
- + Subscriber manages connections to multiple publishers, handles different transport methods (local, shared memory, TCP), and provides both copying and zero-copy message access patterns with automatic acknowledgment. """ + + _SENTINEL = object() + id: UUID - pid: int topic: str + _graph_address: AddressType | None + _graph_task: asyncio.Task[None] + _cur_pubs: set[UUID] + _incoming: NotificationQueue + + # FIXME: This event allows Subscriber.create to block until + # incoming initial connections (UPDATE) has completed. The + # logic is confusing and difficult to follow, but this event + # serves an important purpose for the time being. _initialized: asyncio.Event - _graph_task: "asyncio.Task[None]" - _publishers: dict[UUID, PublisherInfo] - _publisher_tasks: dict[UUID, "asyncio.Task[None]"] - _shms: dict[UUID, SHMContext] - _incoming: "asyncio.Queue[tuple[UUID, int]]" - _graph_service: GraphService + # NOTE: This is an optimization to retain a local handle to channels + # so that dict lookup and wrapper contextmanager aren't in hotpath + _channels: dict[UUID, Channel] @classmethod async def create( - cls, topic: str, graph_service: GraphService, **kwargs + cls, topic: str, graph_address: AddressType | None, **kwargs ) -> "Subscriber": """ Create a new Subscriber instance and register it with the graph server. - + :param topic: The topic this subscriber will listen to. :type topic: str :param graph_service: Service for graph server communication. @@ -68,22 +68,41 @@ async def create( :return: Initialized and registered Subscriber instance. :rtype: Subscriber """ - reader, writer = await graph_service.open_connection() + reader, writer = await GraphService(graph_address).open_connection() writer.write(Command.SUBSCRIBE.value) - id_str = await read_str(reader) - sub = cls(UUID(id_str), topic, graph_service, **kwargs) - writer.write(uint64_to_bytes(sub.pid)) - writer.write(encode_str(sub.topic)) - sub._graph_task = asyncio.create_task(sub._graph_connection(reader, writer)) + writer.write(encode_str(topic)) + sub_id_str = await read_str(reader) + sub_id = UUID(sub_id_str) + + sub = cls(sub_id, topic, graph_address, _guard=cls._SENTINEL, **kwargs) + + sub._graph_task = asyncio.create_task( + sub._graph_connection(reader, writer), + name=f"sub-{sub.id}: _graph_connection", + ) + + # FIXME: We need to wait for _graph_task to service an UPDATE + # then receive a COMPLETE before we return a fully connected + # subscriber ready for recv. await sub._initialized.wait() + + logger.debug(f"created sub {sub.id=} {topic=}") + return sub def __init__( - self, id: UUID, topic: str, graph_service: GraphService, **kwargs + self, + id: UUID, + topic: str, + graph_address: AddressType | None, + _guard = None, + **kwargs ) -> None: """ Initialize a Subscriber instance. - + + DO NOT USE this constructor, use Subscriber.create instead. + :param id: Unique identifier for this subscriber. :type id: UUID :param topic: The topic this subscriber listens to. @@ -92,55 +111,48 @@ def __init__( :type graph_service: GraphService :param kwargs: Additional keyword arguments (unused). """ + if _guard is not self._SENTINEL: + raise TypeError( + "Subscriber cannot be instantiated directly." + "Use 'await Subscriber.create(...)' instead." 
+ ) self.id = id - self.pid = os.getpid() self.topic = topic + self._graph_address = graph_address - self._publishers = dict() - self._publisher_tasks = dict() - self._shms = dict() + self._cur_pubs = set() self._incoming = asyncio.Queue() + self._channels = dict() self._initialized = asyncio.Event() - self._graph_service = graph_service - def close(self) -> None: """ Close the subscriber and cancel all associated tasks. - + Cancels graph connection, all publisher connection tasks, and closes all shared memory contexts. """ self._graph_task.cancel() - for task in self._publisher_tasks.values(): - task.cancel() - for shm in self._shms.values(): - shm.close() async def wait_closed(self) -> None: """ Wait for all subscriber resources to be fully closed. - + Waits for graph connection termination, all publisher connection tasks to complete, and all shared memory contexts to close. """ with suppress(asyncio.CancelledError): await self._graph_task - for task in self._publisher_tasks.values(): - with suppress(asyncio.CancelledError): - await task - for shm in self._shms.values(): - await shm.wait_closed() async def _graph_connection( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter ) -> None: """ Handle communication with the graph server. - + Processes commands from the graph server including COMPLETE and UPDATE operations for managing publisher connections. - + :param reader: Stream reader for receiving commands from graph server. :type reader: asyncio.StreamReader :param writer: Stream writer for responding to graph server. @@ -149,151 +161,75 @@ async def _graph_connection( try: while True: cmd = await reader.read(1) + if not cmd: break - elif cmd == Command.COMPLETE.value: + if cmd == Command.COMPLETE.value: + # FIXME: The only time GraphServer will send us a COMPLETE + # is when it is done with Subscriber.create. Unfortunately + # part of creating the subscriber involves receiving an + # UPDATE with all of the initial connections so that + # messaging resolves as expected immediately after creation + # is completed. The only way we can service this UPDATE + # is by passing control of the GraphServer's StreamReader + # to this task, which can handle the UPDATE -- mid-creation + # then we receive the COMPLETE in here, set the _initialized + # event, which we wait on in Subscriber.create, releasing the + # block and returning a fully connected/ready subscriber. + # While this currently works, its non-obvious what this + # accomplishes and why its implemented this way. I just + # wasted 2 hours removing this seemingly un-necessary event + # to introduce a bug that is only resolved by reintroducing + # the event. Some thought should be put into replacing the + # bespoke communication protocol with something a bit more + # standard (JSON RPC?) with a more common/logical pattern + # for creating/handling comms. We will probably want to + # keep the bespoke comms for the publisher->channel link + # as that is in the hot path. 
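# Illustrative sketch (hypothetical class): the sentinel-guard pattern used above,
# which blocks direct construction and funnels callers through an async factory so
# that registration/handshaking always completes before the object is handed out.

import asyncio

class Client:
    _SENTINEL = object()

    def __init__(self, name: str, _guard=None) -> None:
        if _guard is not Client._SENTINEL:
            raise TypeError("Client cannot be instantiated directly. Use 'await Client.create(...)'")
        self.name = name

    @classmethod
    async def create(cls, name: str) -> "Client":
        client = cls(name, _guard=cls._SENTINEL)
        await asyncio.sleep(0)  # stands in for the async handshake performed in create()
        return client

# asyncio.run(Client.create("example"))   # ok
# Client("example")                       # raises TypeError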
self._initialized.set() elif cmd == Command.UPDATE.value: - pub_addresses: dict[UUID, Address] = {} - connections = await read_str(reader) - connections = connections.strip(",") - if len(connections): - for connection in connections.split(","): - pub_id, pub_address = connection.split("@") - pub_id = UUID(pub_id) - pub_address = Address.from_string(pub_address) - pub_addresses[pub_id] = pub_address - - for id in set(pub_addresses.keys() - self._publishers.keys()): - connected = asyncio.Event() - coro = self._handle_publisher(id, pub_addresses[id], connected) - task_name = f"sub{self.id}:_handle_publisher({id})" - self._publisher_tasks[id] = asyncio.create_task( - coro, name=task_name + update = await read_str(reader) + pub_ids = ( + set([UUID(id) for id in update.split(",")]) if update else set() + ) + + for pub_id in set(pub_ids - self._cur_pubs): + channel = await CHANNELS.register( + pub_id, self.id, self._incoming, self._graph_address ) - await connected.wait() + self._channels[pub_id] = channel - for id in set(self._publishers.keys() - pub_addresses.keys()): - self._publisher_tasks[id].cancel() - with suppress(asyncio.CancelledError): - await self._publisher_tasks[id] + for pub_id in set(self._cur_pubs - pub_ids): + await CHANNELS.unregister(pub_id, self.id, self._graph_address) + del self._channels[pub_id] writer.write(Command.COMPLETE.value) await writer.drain() else: logger.warning( - f"Subscriber {self.id} rx unknown command from GraphServer: {cmd}" + f"Subscriber {self.topic}({self.id}) rx unknown command from GraphServer: {cmd}" ) except (ConnectionResetError, BrokenPipeError): - logger.debug(f"Subscriber {self.id} lost connection to graph server") + logger.debug( + f"Subscriber {self.topic}({self.id}) lost connection to graph server" + ) finally: + for pub_id in self._channels: + await CHANNELS.unregister(pub_id, self.id, self._graph_address) await close_stream_writer(writer) - async def _handle_publisher( - self, id: UUID, address: Address, connected: asyncio.Event - ) -> None: - """ - Handle communication with a specific publisher. - - Establishes connection, exchanges identification, and processes - incoming messages from the publisher using various transport methods. - - :param id: Unique identifier of the publisher. - :type id: UUID - :param address: Network address of the publisher. - :type address: Address - :param connected: Event to signal when connection is established. - :type connected: asyncio.Event - :raises ValueError: If publisher ID doesn't match expected ID. 
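# Illustrative sketch (hypothetical helper): the set-difference reconciliation the
# UPDATE handler above performs -- compare the publisher ids the GraphServer says
# should feed this topic against the ids currently attached, register the new
# ones, and unregister the stale ones.

def reconcile(current: set, desired: set) -> tuple[set, set]:
    to_register = desired - current      # publishers we are not attached to yet
    to_unregister = current - desired    # publishers no longer feeding this topic
    return to_register, to_unregister

# reconcile({"a", "b"}, {"b", "c"})  ->  ({"c"}, {"a"})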
- """ - reader, writer = await asyncio.open_connection(*address) - writer.write(encode_str(str(self.id))) - writer.write(uint64_to_bytes(self.pid)) - writer.write(encode_str(self.topic)) - await writer.drain() - - # Pub replies with current shm name - # We attempt to attach and let pub know if we have SHM access - shm_name = await read_str(reader) - try: - (await self._graph_service.attach_shm(shm_name)).close() - writer.write(uint64_to_bytes(1)) - except (ValueError, OSError): - writer.write(uint64_to_bytes(0)) - await writer.drain() - - pub_id_str = await read_str(reader) - pub_pid = await read_int(reader) - pub_topic = await read_str(reader) - num_buffers = await read_int(reader) - - if id != UUID(pub_id_str): - raise ValueError("Unexpected Publisher ID") - - # NOTE: Not thread safe - if id not in MessageCache: - MessageCache[id] = Cache(num_buffers) - - self._publishers[id] = PublisherInfo(id, writer, pub_pid, pub_topic, address) - - connected.set() - - try: - while True: - msg = await reader.read(1) - if not msg: - break - - msg_id_bytes = await reader.read(UINT64_SIZE) - msg_id = bytes_to_uint(msg_id_bytes) - - if msg == Command.TX_SHM.value: - shm_name = await read_str(reader) - - if id not in self._shms or self._shms[id].name != shm_name: - if id in self._shms: - self._shms[id].close() - try: - self._shms[id] = await self._graph_service.attach_shm( - shm_name - ) - except ValueError: - logger.info( - "Invalid SHM received from publisher; may be dead" - ) - raise - - # FIXME: TCP connections could be more efficient. - # https://github.com/iscoe/ezmsg/issues/5 - elif msg == Command.TX_TCP.value: - buf_size = await read_int(reader) - obj_bytes = await reader.readexactly(buf_size) - - with MessageMarshal.obj_from_mem(memoryview(obj_bytes)) as obj: - MessageCache[id].put(msg_id, obj) - - self._incoming.put_nowait((id, msg_id)) - - except (ConnectionResetError, BrokenPipeError): - logger.debug(f"connection fail: sub:{self.id} -> pub:{id}") - - finally: - await close_stream_writer(self._publishers[id].writer) - del self._publishers[id] - logger.debug(f"disconnected: sub:{self.id} -> pub:{id}") - async def recv(self) -> typing.Any: """ Receive the next message with a deep copy. - + This method creates a deep copy of the received message, allowing safe modification without affecting the original cached message. - + :return: Deep copy of the received message. :rtype: typing.Any """ @@ -303,30 +239,18 @@ async def recv(self) -> typing.Any: return out_msg @asynccontextmanager - async def recv_zero_copy(self) -> AsyncGenerator[typing.Any, None]: + async def recv_zero_copy(self) -> typing.AsyncGenerator[typing.Any, None]: """ Receive the next message with zero-copy access. - + This context manager provides direct access to the cached message without copying. The message should not be modified or stored beyond the context manager's scope. - + :return: Context manager yielding the received message. 
:rtype: collections.abc.AsyncGenerator[typing.Any, None] """ - id, msg_id = await self._incoming.get() - msg_id_bytes = uint64_to_bytes(msg_id) + pub_id, msg_id = await self._incoming.get() - try: - shm = self._shms.get(id, None) - with MessageCache[id].get(msg_id, shm) as msg: - yield msg - - finally: - if id in self._publishers: - try: - ack = Command.RX_ACK.value + msg_id_bytes - self._publishers[id].writer.write(ack) - await self._publishers[id].writer.drain() - except (BrokenPipeError, ConnectionResetError): - logger.debug(f"ack fail: sub:{self.id} -> pub:{id}") + with self._channels[pub_id].get(msg_id, self.id) as msg: + yield msg diff --git a/src/ezmsg/core/test.py b/src/ezmsg/core/test.py index ce1e705b..48ec4d95 100644 --- a/src/ezmsg/core/test.py +++ b/src/ezmsg/core/test.py @@ -1,15 +1,18 @@ - # tutorial_pipeline.py +# tutorial_pipeline.py import ezmsg.core as ez from dataclasses import dataclass from collections.abc import AsyncGenerator + class CountSettings(ez.Settings): iterations: int + @dataclass class CountMessage: value: int + class Count(ez.Unit): SETTINGS = CountSettings @@ -24,6 +27,7 @@ async def count(self) -> AsyncGenerator: raise ez.NormalTermination + class AddOne(ez.Unit): INPUT_COUNT = ez.InputStream(CountMessage) OUTPUT_PLUS_ONE = ez.OutputStream(CountMessage) @@ -33,6 +37,7 @@ class AddOne(ez.Unit): async def on_message(self, message) -> AsyncGenerator: yield self.OUTPUT_PLUS_ONE, CountMessage(value=message.value + 1) + class PrintValue(ez.Unit): INPUT = ez.InputStream(CountMessage) @@ -40,13 +45,14 @@ class PrintValue(ez.Unit): async def on_message(self, message) -> None: print(message.value) + components = { "COUNT": Count(settings=CountSettings(iterations=10)), "ADD_ONE": AddOne(), - "PRINT": PrintValue() + "PRINT": PrintValue(), } connections = ( (components["COUNT"].OUTPUT_COUNT, components["ADD_ONE"].INPUT_COUNT), - (components["ADD_ONE"].OUTPUT_PLUS_ONE, components["PRINT"].INPUT) + (components["ADD_ONE"].OUTPUT_PLUS_ONE, components["PRINT"].INPUT), ) -ez.run(components=components, connections=connections) \ No newline at end of file +ez.run(components=components, connections=connections) diff --git a/src/ezmsg/core/unit.py b/src/ezmsg/core/unit.py index 045019a1..d957d06c 100644 --- a/src/ezmsg/core/unit.py +++ b/src/ezmsg/core/unit.py @@ -62,11 +62,11 @@ def __init__( class Unit(Component, metaclass=UnitMeta): """ Represents a single step in the graph. - + Units can subscribe, publish, and have tasks. Units are the fundamental building blocks of ezmsg applications that perform actual computation and message processing. To create a Unit, inherit from the Unit class. - + :param settings: Optional settings object for unit configuration :type settings: Settings | None """ @@ -97,11 +97,11 @@ async def setup(self) -> None: async def initialize(self) -> None: """ Runs when the Unit is instantiated. - + This is called from within the same process this unit will live in. - This lifecycle hook can be overridden. It can be run as async functions + This lifecycle hook can be overridden. It can be run as async functions by simply adding the async keyword when overriding. - + This method is where you should initialize your unit's state and prepare for message processing. """ @@ -110,11 +110,11 @@ async def initialize(self) -> None: async def shutdown(self) -> None: """ Runs when the Unit terminates. - + This is called from within the same process this unit will live in. - This lifecycle hook can be overridden. 
It can be run as async functions + This lifecycle hook can be overridden. It can be run as async functions by simply adding the async keyword when overriding. - + This method is where you should clean up resources and perform any necessary shutdown procedures. """ @@ -124,7 +124,7 @@ async def shutdown(self) -> None: def publisher(stream: OutputStream): """ A decorator for a method that publishes to a stream in the task/messaging thread. - + An async function will yield messages on the designated :obj:`OutputStream`. A function can have both ``@subscriber`` and ``@publisher`` decorators. @@ -135,7 +135,7 @@ def publisher(stream: OutputStream): :raises ValueError: If stream is not an OutputStream Example usage: - + .. code-block:: python from collections.abc import AsyncGenerator @@ -163,8 +163,8 @@ def pub_factory(func): def subscriber(stream: InputStream, zero_copy: bool = False): """ A decorator for a method that subscribes to a stream in the task/messaging thread. - - An async function will run once per message received from the :obj:`InputStream` + + An async function will run once per message received from the :obj:`InputStream` it subscribes to. A function can have both ``@subscriber`` and ``@publisher`` decorators. :param stream: The input stream to subscribe to @@ -203,7 +203,7 @@ def sub_factory(func): def main(func: Callable): """ A decorator which designates this function to run as the main thread for this Unit. - + A Unit may only have one main function. The main function runs independently of the message processing and is typically used for initialization, background processing, or cleanup tasks. @@ -220,7 +220,7 @@ def main(func: Callable): def timeit(func: Callable): """ A decorator that logs the execution time of the decorated function. - + ezmsg will log the amount of time this function takes to execute to the ezmsg logger. This is useful for performance monitoring and optimization. The execution time is logged in milliseconds. @@ -252,7 +252,7 @@ def wrapper(self, *args, **kwargs): def thread(func: Callable): """ A decorator which designates this function to run as a background thread for this Unit. - + Thread functions run concurrently with the main message processing and can be used for background tasks, monitoring, or other concurrent operations. @@ -268,7 +268,7 @@ def thread(func: Callable): def task(func: Callable): """ A decorator which designates this function to run as a task in the task/messaging thread. - + Task functions are part of the main message processing pipeline and are executed within the unit's primary execution context. @@ -284,7 +284,7 @@ def task(func: Callable): def process(func: Callable): """ A decorator which designates this function to run in its own process. - + Process functions run in separate processes for isolation and can be used for CPU-intensive operations or when process isolation is required. diff --git a/src/ezmsg/core/util.py b/src/ezmsg/core/util.py index 54cae73a..3bb168f8 100644 --- a/src/ezmsg/core/util.py +++ b/src/ezmsg/core/util.py @@ -1,5 +1,4 @@ from collections.abc import Mapping -import sys import typing @@ -9,7 +8,7 @@ def is_dict_like(value: typing.Any) -> typing.TypeGuard[Mapping]: """ Check if a value behaves like a dictionary. - + This function checks if the value has the basic dictionary interface by verifying it has 'keys' and '__getitem__' attributes. @@ -29,7 +28,7 @@ def either_dict_or_kwargs( """ Handle flexible argument passing patterns for functions that accept either positional dict or keyword arguments. 
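# Illustrative sketch (hypothetical unit): combining @subscriber(zero_copy=True)
# with @publisher on one method, as the decorator docs above describe. With
# zero_copy=True the handler receives the cached message directly, so it must
# neither mutate nor retain it.

import ezmsg.core as ez
from collections.abc import AsyncGenerator

class Passthrough(ez.Unit):
    INPUT = ez.InputStream(dict)
    OUTPUT = ez.OutputStream(dict)

    @ez.subscriber(INPUT, zero_copy=True)
    @ez.publisher(OUTPUT)
    async def forward(self, msg: dict) -> AsyncGenerator:
        # Inspect msg read-only, then forward it downstream unchanged.
        yield self.OUTPUT, msg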
- + This utility function helps implement the common pattern where a function can accept either a dictionary as the first argument or keyword arguments, but not both. diff --git a/src/ezmsg/util/__init__.py b/src/ezmsg/util/__init__.py index ccc806ea..c722ac66 100644 --- a/src/ezmsg/util/__init__.py +++ b/src/ezmsg/util/__init__.py @@ -6,7 +6,7 @@ Key modules: - :mod:`ezmsg.util.debuglog`: Debug logging utilities -- :mod:`ezmsg.util.generator`: Generator-based message processing decorators +- :mod:`ezmsg.util.generator`: Generator-based message processing decorators - :mod:`ezmsg.util.messagecodec`: JSON encoding/decoding for message logging - :mod:`ezmsg.util.messagegate`: Message flow control and gating - :mod:`ezmsg.util.messagelogger`: File-based message logging diff --git a/src/ezmsg/util/messagecodec.py b/src/ezmsg/util/messagecodec.py index 779be16a..2e3db03f 100644 --- a/src/ezmsg/util/messagecodec.py +++ b/src/ezmsg/util/messagecodec.py @@ -10,6 +10,7 @@ The MessageEncoder and MessageDecoder classes handle automatic conversion between Python objects and JSON representations suitable for file logging. """ + from collections.abc import Iterable, Generator import json import pickle @@ -44,7 +45,7 @@ class LogStart: ... def type_str(obj: typing.Any) -> str: """ Get a string representation of an object's type for serialization. - + :param obj: Object to get type string for :type obj: typing.Any :return: String representation in format 'module:qualname' @@ -58,7 +59,7 @@ def type_str(obj: typing.Any) -> str: def import_type(typestr: str) -> type: """ Import a type from a string representation. - + :param typestr: String representation in format 'module:qualname' :type typestr: str :return: The imported type @@ -78,13 +79,14 @@ def import_type(typestr: str) -> type: class MessageEncoder(json.JSONEncoder): """ JSON encoder for ezmsg messages with support for dataclasses, numpy arrays, and arbitrary objects. - + This encoder extends the standard JSON encoder to handle: - + - Dataclass objects (serialized as dictionaries with type information) - NumPy arrays (serialized as base64-encoded data with metadata) - Other objects via pickle (as fallback) """ + def default(self, o: typing.Any): if is_dataclass(o): return { @@ -116,12 +118,13 @@ def default(self, o: typing.Any): class StampedMessage(typing.NamedTuple): """ A message with an associated timestamp. - + :param msg: The message object :type msg: typing.Any :param timestamp: Optional timestamp for the message :type timestamp: float | None """ + msg: typing.Any timestamp: float | None @@ -129,10 +132,10 @@ class StampedMessage(typing.NamedTuple): def _object_hook(obj: dict[str, typing.Any]) -> typing.Any: """ JSON object hook for decoding ezmsg messages. - + Handles reconstruction of dataclasses, numpy arrays, and pickled objects from their JSON representations. - + :param obj: Dictionary from JSON decoder :type obj: dict[str, typing.Any] :return: Reconstructed object @@ -145,9 +148,7 @@ def _object_hook(obj: dict[str, typing.Any]) -> typing.Any: if obj_type is not None: if np and obj_type == NDARRAY_TYPE: data_bytes: str | None = obj.get(NDARRAY_DATA) - data_shape: Iterable[int] | None = obj.get( - NDARRAY_SHAPE - ) + data_shape: Iterable[int] | None = obj.get(NDARRAY_SHAPE) data_dtype: npt.DTypeLike | None = obj.get(NDARRAY_DTYPE) if ( @@ -175,22 +176,21 @@ def _object_hook(obj: dict[str, typing.Any]) -> typing.Any: class MessageDecoder(json.JSONDecoder): """ JSON decoder for ezmsg messages. 
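# Illustrative sketch (hypothetical message type and file name): round-tripping a
# dataclass through MessageEncoder / MessageDecoder, the same pair used for the
# file logging described in this module.

import json
from dataclasses import dataclass
from ezmsg.util.messagecodec import MessageEncoder, MessageDecoder

@dataclass
class Sample:
    t: float
    value: float

line = json.dumps(Sample(t=0.0, value=1.5), cls=MessageEncoder)
restored = json.loads(line, cls=MessageDecoder)    # -> Sample(t=0.0, value=1.5)

# Replaying a file written by MessageLogger later:
# from pathlib import Path
# for msg in message_log(Path("session.txt")): ...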
- + Automatically reconstructs dataclasses, numpy arrays, and pickled objects from their JSON representations using the _object_hook function. """ + def __init__(self, *args, **kwargs): json.JSONDecoder.__init__(self, object_hook=_object_hook, *args, **kwargs) - - def message_log( fname: Path, return_object: bool = True ) -> Generator[typing.Any, None, None]: """ Generator function to read messages from a log file created by MessageLogger. - + :param fname: Path to the log file :type fname: Path :param return_object: If True, yield only the message objects; if False, yield complete log entries diff --git a/src/ezmsg/util/messagegate.py b/src/ezmsg/util/messagegate.py index 5d9e35c4..379c95b9 100644 --- a/src/ezmsg/util/messagegate.py +++ b/src/ezmsg/util/messagegate.py @@ -9,7 +9,7 @@ class GateMessage: """ Send this message to ``INPUT_GATE`` to open or close the gate. - + :param open: True to open the gate (allow messages), False to close it (discard messages) :type open: bool """ @@ -66,7 +66,7 @@ async def initialize(self) -> None: def set_gate(self, set_open: bool) -> None: """ Set the gate open/closed state and reset message counter. - + :param set_open: True to open gate, False to close it :type set_open: bool """ diff --git a/src/ezmsg/util/messagelogger.py b/src/ezmsg/util/messagelogger.py index d91b9fb0..cd713f6c 100644 --- a/src/ezmsg/util/messagelogger.py +++ b/src/ezmsg/util/messagelogger.py @@ -16,7 +16,7 @@ def log_object(obj: typing.Any) -> str: """ Convert an object to a JSON string with timestamp for logging. - + :param obj: Object to convert to log string :type obj: typing.Any :return: JSON string containing timestamp and object @@ -86,7 +86,7 @@ class MessageLogger(ez.Unit): def open_file(self, filepath: Path) -> Path | None: """ Open a file for message logging. - + :param filepath: Path to the file to open :type filepath: Path :return: File path if file successfully opened, otherwise None @@ -109,7 +109,7 @@ def open_file(self, filepath: Path) -> Path | None: def close_file(self, filepath: Path) -> Path | None: """ Close a file that was being used for message logging. - + :param filepath: Path to the file to close :type filepath: Path :return: File path if file successfully closed, otherwise None diff --git a/src/ezmsg/util/messages/axisarray.py b/src/ezmsg/util/messages/axisarray.py index 6d627538..5324cc57 100644 --- a/src/ezmsg/util/messages/axisarray.py +++ b/src/ezmsg/util/messages/axisarray.py @@ -12,7 +12,6 @@ try: import numpy as np import numpy.typing as npt - import numpy.lib.stride_tricks as nps except ModuleNotFoundError: ez.logger.error( 'Install ezmsg with the AxisArray extra:pip install "ezmsg[AxisArray]"' @@ -36,6 +35,7 @@ @dataclass class AxisBase(ABC): """Abstract base class for axes types used by AxisArray.""" + unit: str = "" @typing.overload @@ -50,13 +50,13 @@ def value(self, x): @dataclass class LinearAxis(AxisBase): """ - An axis implementation for sparse axes with regular intervals between elements. - + An axis implementation for sparse axes with regular intervals between elements. + It is called "linear" because it provides a simple linear mapping between indices and element values: value = (index * gain) + offset. - A typical example is a time axis (TimeAxis), with regular sampling rate. - + A typical example is a time axis (TimeAxis), with regular sampling rate. 
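# Illustrative sketch: a 500 Hz time axis expressed as a LinearAxis, where
# value = index * gain + offset. create_time_axis(fs) is the convenience
# constructor mentioned below.

import numpy as np
from ezmsg.util.messages.axisarray import LinearAxis

fs = 500.0
t_axis = LinearAxis(gain=1.0 / fs, offset=0.0, unit="s")
t_axis.value(np.arange(3))    # -> array([0.   , 0.002, 0.004])
t_axis.index(0.004)           # -> 2 (nearest index; rounding function defaults to np.rint)

t_axis = LinearAxis.create_time_axis(fs=fs)   # same intent, one call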
+ :param gain: Step size (scaling factor) for the linear axis :type gain: float :param offset: The offset (value) of the first sample @@ -64,6 +64,7 @@ class LinearAxis(AxisBase): :param unit: The unit of measurement for this axis (inherited from AxisBase) :type unit: str """ + gain: float = 1.0 offset: float = 0.0 @@ -74,7 +75,7 @@ def value(self, x: npt.NDArray[np.int_]) -> npt.NDArray[np.float64]: ... def value(self, x): """ Convert index(es) to axis value(s) using the linear transformation. - + :param x: Index or array of indices to convert :type x: int | npt.NDArray[:class:`numpy.int_`] :return: Corresponding axis value(s) @@ -105,7 +106,7 @@ def index(self, v, fn=np.rint): def create_time_axis(cls, fs: float, offset: float = 0.0) -> "LinearAxis": """ Convenience method to construct a LinearAxis for time. - + :param fs: Sampling frequency in Hz :type fs: float :param offset: Time offset in seconds (default: 0.0) @@ -120,17 +121,18 @@ def create_time_axis(cls, fs: float, offset: float = 0.0) -> "LinearAxis": class ArrayWithNamedDims: """ Base class for arrays with named dimensions. - + This class provides a foundation for arrays where each dimension has a name, enabling more intuitive array manipulation and access patterns. - + :param data: The underlying numpy array data :type data: npt.NDArray :param dims: List of dimension names, must match the array's number of dimensions :type dims: list[str] - + :raises ValueError: If dims length doesn't match data.ndim or contains duplicate names """ + data: npt.NDArray dims: list[str] @@ -154,10 +156,10 @@ def __eq__(self, other): class CoordinateAxis(AxisBase, ArrayWithNamedDims): """ An axis implementation that uses explicit coordinate values stored in an array. - + This class allows for non-linear or irregularly spaced coordinate systems by storing the actual coordinate values in a data array. - + Inherits from both AxisBase and ArrayWithNamedDims, combining axis functionality with named dimension support. @@ -168,6 +170,7 @@ class CoordinateAxis(AxisBase, ArrayWithNamedDims): :param dims: List of dimension names, must match the array's number of dimensions (inherited from ArrayWithNamedDims) :type dims: list[str] """ + @typing.overload def value(self, x: int) -> typing.Any: ... @typing.overload @@ -175,7 +178,7 @@ def value(self, x: npt.NDArray[np.int_]) -> npt.NDArray: ... def value(self, x): """ Get coordinate value(s) at the given index(es). - + :param x: Index or array of indices to lookup :type x: int | npt.NDArray[:class:`numpy.int_`] :return: Coordinate value(s) at the specified index(es) @@ -188,15 +191,15 @@ def value(self, x): class AxisArray(ArrayWithNamedDims): """ A lightweight message class comprising a numpy ndarray and its metadata. - + AxisArray extends ArrayWithNamedDims to provide a complete data structure for scientific computing with named dimensions, axis coordinate systems, and metadata. It's designed to be similar to xarray.DataArray but optimized for message passing in streaming applications. 
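# Illustrative sketch (hypothetical data): one second of 4-channel data packaged
# as an AxisArray, with a LinearAxis describing the "time" dimension.

import numpy as np
from ezmsg.util.messages.axisarray import AxisArray, LinearAxis

fs = 500.0
aa = AxisArray(
    data=np.zeros((int(fs), 4)),
    dims=["time", "ch"],
    axes={"time": LinearAxis.create_time_axis(fs=fs)},
    key="example_device",
)
aa.shape         # -> (500, 4)
aa.ax("time")    # -> AxisInfo bundling the axis, its index in dims, and its size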
- + :param data: The underlying numpy array data (inherited from ArrayWithNamedDims) :type data: npt.NDArray - :param dims: List of dimension names (inherited from ArrayWithNamedDims) + :param dims: List of dimension names (inherited from ArrayWithNamedDims) :type dims: list[str] :param axes: Dictionary mapping dimension names to their axis coordinate systems :type axes: dict[str, AxisBase] @@ -205,6 +208,7 @@ class AxisArray(ArrayWithNamedDims): :param key: Optional key identifier for this array, typically used to specify source device (default is empty string) :type key: str """ + axes: dict[str, AxisBase] = field(default_factory=dict) attrs: dict[str, typing.Any] = field(default_factory=dict) key: str = "" @@ -216,7 +220,7 @@ def __eq__(self, other): # returns NotImplemented if classes aren't equal. Unintuitively, # NotImplemented seems to evaluate as 'True' in an if statement. equal = super().__eq__(other) - if equal != True: + if equal is not True: return equal # checks for AxisArray fields @@ -233,10 +237,11 @@ def __eq__(self, other): class Axis(LinearAxis): """ Deprecated alias for LinearAxis. - + .. deprecated:: 3.6.0 Use :class:`LinearAxis` instead. """ + def __post_init__(self) -> None: warnings.warn( "AxisArray.Axis is a deprecated alias for LinearAxis", @@ -263,19 +268,20 @@ def index(self, val): class AxisInfo: """ Container for axis information including the axis object, index, and size. - + This class provides a convenient way to access both the axis coordinate system and metadata about where that axis appears in the array structure. - + :param axis: The axis coordinate system :type axis: AxisBase :param idx: The index of this axis in the AxisArray object's dimension list - :type idx: int + :type idx: int :param size: The size of this dimension (stored as None for CoordinateAxis which determines size from data) :type size: int | None - + :raises ValueError: If size rules are violated for the axis type """ + axis: AxisBase idx: int # TODO (kpilch): rename this to _size as preferred usage is len(obj), not obj.size @@ -300,7 +306,7 @@ def __len__(self) -> int: def indices(self) -> npt.NDArray[np.int_]: """ Get array of all valid indices for this axis. - + :return: Array of indices from 0 to len(self)-1 :rtype: npt.NDArray[:class:`numpy.int_`] """ @@ -310,7 +316,7 @@ def indices(self) -> npt.NDArray[np.int_]: def values(self) -> npt.NDArray: """ Get array of coordinate values for all indices of this axis. - + :return: Array of coordinate values computed from axis.value(indices) :rtype: npt.NDArray """ @@ -323,13 +329,13 @@ def isel( ) -> T: """ Select data using integer-based indexing along specified dimensions. - + This method allows for flexible indexing using integers, slices, or arrays of integers to select subsets of the data along named dimensions. - + :param indexers: Dictionary of {dimension_name: indexer} pairs :type indexers: typing.Any | None - :param indexers_kwargs: Alternative way to specify indexers as keyword arguments + :param indexers_kwargs: Alternative way to specify indexers as keyword arguments :type indexers_kwargs: typing.Any :return: New AxisArray with selected data :rtype: T @@ -368,14 +374,14 @@ def sel( ) -> T: """ Select data using label-based indexing along specified dimensions. - + This method allows selection using real-world coordinate values rather than integer indices. Currently supports only slice objects and LinearAxis. 
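# Illustrative sketch (continuing the hypothetical `aa` above): integer-based vs.
# label-based selection. Indexers may be given as a dict or as keyword arguments,
# but not both at once.

first_100 = aa.isel(time=slice(0, 100))          # by integer index
first_100 = aa.isel({"time": slice(0, 100)})     # same selection, dict form

first_200ms = aa.sel(time=slice(0.0, 0.2))       # by axis value; slices over a LinearAxis only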
- + :param indexers: Dictionary of {dimension_name: slice_indexer} pairs :type indexers: typing.Any | None :param indexers_kwargs: Alternative way to specify indexers as keyword arguments - :type indexers_kwargs: typing.Any + :type indexers_kwargs: typing.Any :return: New AxisArray with selected data :rtype: T :raises ValueError: If indexer is not a slice or axis is not a LinearAxis @@ -400,7 +406,7 @@ def sel( def shape(self) -> tuple[int, ...]: """ Shape of data. - + :return: Tuple representing the shape of the underlying data array :rtype: tuple[int, ...] """ @@ -409,10 +415,10 @@ def shape(self) -> tuple[int, ...]: def to_xr_dataarray(self) -> "DataArray": """ Convert this AxisArray to an xarray DataArray. - + This method creates an xarray DataArray with equivalent data, coordinates, dimensions, and attributes. Useful for interoperability with the xarray ecosystem. - + :return: xarray DataArray representation of this AxisArray :rtype: xarray.DataArray """ @@ -430,7 +436,7 @@ def to_xr_dataarray(self) -> "DataArray": def ax(self, dim: str | int) -> AxisInfo: """ Get AxisInfo for a specified dimension. - + :param dim: Dimension name or index :type dim: str | int :return: AxisInfo containing axis, index, and size information @@ -445,7 +451,7 @@ def ax(self, dim: str | int) -> AxisInfo: def get_axis(self, dim: str | int) -> AxisBase: """ Get the axis coordinate system for a specified dimension. - + :param dim: Dimension name or index :type dim: str | int :return: The axis coordinate system (defaults to LinearAxis if not specified) @@ -461,7 +467,7 @@ def get_axis(self, dim: str | int) -> AxisBase: def get_axis_name(self, dim: int) -> str: """ Get the dimension name for a given axis index. - + :param dim: The axis index :type dim: int :return: The dimension name @@ -472,7 +478,7 @@ def get_axis_name(self, dim: int) -> str: def get_axis_idx(self, dim: str) -> int: """ Get the axis index for a given dimension name. - + :param dim: The dimension name :type dim: str :return: The axis index @@ -487,7 +493,7 @@ def get_axis_idx(self, dim: str) -> int: def axis_idx(self, dim: str | int) -> int: """ Get the axis index for a given dimension name or pass through if already an int. - + :param dim: Dimension name or index :type dim: str | int :return: The axis index @@ -502,7 +508,7 @@ def _axis_idx(self, dim: str | int) -> int: def as2d(self, dim: str | int) -> npt.NDArray: """ Get a 2D view of the data with the specified dimension as the first axis. - + :param dim: Dimension name or index to move to first axis :type dim: str | int :return: 2D array view with shape (dim_size, remaining_elements) @@ -510,15 +516,13 @@ def as2d(self, dim: str | int) -> npt.NDArray: """ return as2d(self.data, self.axis_idx(dim), xp=get_namespace(self.data)) - def iter_over_axis( - self: T, axis: str | int - ) -> Generator[T, None, None]: + def iter_over_axis(self: T, axis: str | int) -> Generator[T, None, None]: """ Iterate over slices along the specified axis. - + Yields AxisArray objects for each slice along the given axis, with that dimension removed from the resulting arrays. - + :param axis: Dimension name or index to iterate over :type axis: str | int :yields: AxisArray objects for each slice along the axis @@ -540,16 +544,14 @@ def iter_over_axis( yield it_aa @contextmanager - def view2d( - self, dim: str | int - ) -> Generator[npt.NDArray, None, None]: + def view2d(self, dim: str | int) -> Generator[npt.NDArray, None, None]: """ Context manager providing a 2D view of the data. 
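# Illustrative sketch (continuing the hypothetical `aa`): iterating one time step
# at a time (the iterated dimension is removed from each yielded AxisArray), and
# grabbing a 2D view with a chosen dimension first.

for sample in aa.iter_over_axis("time"):
    assert sample.dims == ["ch"]
    break

with aa.view2d("ch") as two_d:
    two_d.shape   # -> aa.shape2d("ch"), i.e. (channels, everything else flattened)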
- + Yields a 2D array view with the specified dimension as the first axis. Changes to the yielded array may be reflected in the original data. - - :param dim: Dimension name or index to move to first axis + + :param dim: Dimension name or index to move to first axis :type dim: str | int :yields: 2D array view with shape (dim_size, remaining_elements) :rtype: collections.abc.Generator[npt.NDArray, None, None] @@ -560,8 +562,8 @@ def view2d( def shape2d(self, dim: str | int) -> tuple[int, int]: """ Get the 2D shape when viewing data with specified dimension first. - - :param dim: Dimension name or index + + :param dim: Dimension name or index :type dim: str | int :return: Tuple of (dim_size, remaining_elements) :rtype: tuple[int, int] @@ -577,7 +579,7 @@ def concatenate( ) -> T: """ Concatenate multiple AxisArray objects along a specified dimension. - + :param aas: Variable number of AxisArray objects to concatenate :type aas: T :param dim: Dimension name along which to concatenate @@ -637,7 +639,7 @@ def transpose( ) -> T: """ Transpose (reorder) the dimensions of an AxisArray. - + :param aa: The AxisArray to transpose :type aa: T :param dims: New dimension order (names or indices). If None, reverses all dimensions @@ -653,9 +655,7 @@ def transpose( return replace(aa, data=new_data, dims=new_dims, axes=aa.axes) -def slice_along_axis( - in_arr: npt.NDArray, sl: slice | int, axis: int -) -> npt.NDArray: +def slice_along_axis(in_arr: npt.NDArray, sl: slice | int, axis: int) -> npt.NDArray: """ Slice the input array along a specified axis using the given slice object or integer index. Integer arguments to `sl` will cause the sliced dimension to be dropped. @@ -689,7 +689,7 @@ def sliding_win_oneaxis( ) -> npt.NDArray: """ Generates a view of an array using a sliding window of specified length along a specified axis of the input array. - This is a slightly optimized version of nps.sliding_window_view with a few important differences: + This is a slightly optimized version of numpy.lib.stride_tricks.sliding_window_view with a few important differences: - This only accepts a single nwin and a single axis, thus we can skip some checks. - The new `win` axis precedes immediately the original target axis, unlike sliding_window_view where the @@ -757,7 +757,7 @@ def _as2d( ) -> tuple[npt.NDArray, tuple[int, ...]]: """ Internal helper function to reshape array to 2D with specified axis first. - + :param in_arr: Input array to be reshaped :type in_arr: npt.NDArray :param axis: Axis to move to first position (default: 0) @@ -781,7 +781,7 @@ def _as2d( def as2d(in_arr: npt.NDArray, axis: int = 0, *, xp) -> npt.NDArray: """ Reshape array to 2D with specified axis first. - + :param in_arr: Input array :type in_arr: npt.NDArray :param axis: Axis to move to first position (default: 0) @@ -795,9 +795,7 @@ def as2d(in_arr: npt.NDArray, axis: int = 0, *, xp) -> npt.NDArray: @contextmanager -def view2d( - in_arr: npt.NDArray, axis: int = 0 -) -> Generator[npt.NDArray, None, None]: +def view2d(in_arr: npt.NDArray, axis: int = 0) -> Generator[npt.NDArray, None, None]: """ Context manager providing 2D view of the array, no matter what input dimensionality is. Yields a view of underlying data when possible, changes to data in yielded @@ -806,8 +804,8 @@ def view2d( NOTE: In practice, I'm not sure this is very useful because it requires modifying the numpy array data in-place, which limits its application to zero-copy messages - NOTE: The context manager allows the use of `with` when calling `view2d`. 
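# Illustrative sketch: slice_along_axis() keeps the axis when given a slice and
# drops it entirely when given an int, per the docstring above.

import numpy as np
from ezmsg.util.messages.axisarray import slice_along_axis

x = np.zeros((500, 4))
slice_along_axis(x, slice(0, 100), axis=0).shape   # -> (100, 4)
slice_along_axis(x, 2, axis=1).shape               # -> (500,)  single channel, axis dropped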
- + NOTE: The context manager allows the use of `with` when calling `view2d`. + :param in_arr: Input array to be viewed as 2D :type in_arr: npt.NDArray :param axis: Dimension index to move to first axis (Default = 0) @@ -829,7 +827,7 @@ def view2d( def shape2d(arr: npt.NDArray, axis: int = 0) -> tuple[int, int]: """ Calculate the 2D shape when viewing array with specified axis first. - + :param arr: Input array :type arr: npt.NDArray :param axis: Axis to move to first position (default: 0) diff --git a/src/ezmsg/util/messages/chunker.py b/src/ezmsg/util/messages/chunker.py index 402fd2bc..20229ec2 100644 --- a/src/ezmsg/util/messages/chunker.py +++ b/src/ezmsg/util/messages/chunker.py @@ -1,7 +1,6 @@ import asyncio from collections.abc import Generator, AsyncGenerator import traceback -import typing import numpy as np import numpy.typing as npt @@ -21,7 +20,7 @@ def array_chunker( """ Create a generator that yields AxisArrays containing chunks of an array along a specified axis. - + The generator should be useful for quick offline analyses, tests, or examples. This generator probably is not useful for online streaming applications. @@ -39,7 +38,7 @@ def array_chunker( :rtype: collections.abc.Generator[AxisArray, None, None] """ - if not type(data) == np.ndarray: + if not type(data) == np.ndarray: # noqa: E721 # hot path optimization data = np.array(data) n_chunks = int(np.ceil(data.shape[axis] / chunk_len)) tvec = np.arange(n_chunks, dtype=float) * chunk_len / fs + tzero @@ -66,7 +65,7 @@ def array_chunker( class ArrayChunkerSettings(ez.Settings): """ Settings for ArrayChunker unit. - + Configuration for chunking array data along a specified axis with timing information. :param data: An array_like object to iterate over, chunk-by-chunk. @@ -80,6 +79,7 @@ class ArrayChunkerSettings(ez.Settings): :param tzero: The time offset of the first chunk. Will only be used to make the time axis. :type tzero: float """ + data: npt.ArrayLike chunk_len: int axis: int = 0 @@ -90,10 +90,11 @@ class ArrayChunkerSettings(ez.Settings): class ArrayChunker(ez.Unit): """ Unit for chunking array data along a specified axis. - + Converts array data into sequential chunks along a specified axis, with proper timing axis information for streaming applications. """ + SETTINGS = ArrayChunkerSettings STATE = GenState @@ -103,7 +104,7 @@ class ArrayChunker(ez.Unit): async def initialize(self) -> None: """ Initialize the ArrayChunker unit. - + Sets up the generator for chunking operations based on current settings. """ self.construct_generator() @@ -111,7 +112,7 @@ async def initialize(self) -> None: def construct_generator(self): """ Construct the chunking generator with current settings. - + Creates a new array_chunker generator instance using the unit's settings. """ self.STATE.gen = array_chunker( @@ -126,7 +127,7 @@ def construct_generator(self): async def on_settings(self, msg: ez.Settings) -> None: """ Handle incoming settings updates. - + :param msg: New settings to apply. :type msg: ez.Settings """ @@ -137,10 +138,10 @@ async def on_settings(self, msg: ez.Settings) -> None: async def send_chunk(self) -> AsyncGenerator: """ Publisher method that yields data chunks. - + Continuously yields chunks from the generator until exhausted, with proper exception handling for completion cases. - + :return: Async generator yielding AxisArray chunks. 
:rtype: collections.abc.AsyncGenerator """ diff --git a/src/ezmsg/util/messages/key.py b/src/ezmsg/util/messages/key.py index 23e982fe..1850f053 100644 --- a/src/ezmsg/util/messages/key.py +++ b/src/ezmsg/util/messages/key.py @@ -1,6 +1,5 @@ from collections.abc import AsyncGenerator, Generator import traceback -import typing import numpy as np @@ -32,22 +31,24 @@ def set_key(key: str = "") -> Generator[AxisArray, AxisArray, None]: class KeySettings(ez.Settings): """ Settings for key manipulation units. - + Configuration for setting or filtering AxisArray keys. - + :param key: The string to set as the key. :type key: str """ + key: str = "" class SetKey(ez.Unit): """ Unit for setting the key of incoming AxisArray messages. - + Modifies the key field of AxisArray messages while preserving all other data. Uses zero-copy operations for efficient processing. """ + STATE = GenState SETTINGS = KeySettings @@ -58,7 +59,7 @@ class SetKey(ez.Unit): async def initialize(self) -> None: """ Initialize the SetKey unit. - + Sets up the generator for key modification operations. """ self.construct_generator() @@ -66,7 +67,7 @@ async def initialize(self) -> None: def construct_generator(self): """ Construct the key-setting generator with current settings. - + Creates a new set_key generator instance using the unit's key setting. """ self.STATE.gen = set_key(key=self.SETTINGS.key) @@ -75,7 +76,7 @@ def construct_generator(self): async def on_settings(self, msg: ez.Settings) -> None: """ Handle incoming settings updates. - + :param msg: New settings to apply. :type msg: ez.Settings """ @@ -87,10 +88,10 @@ async def on_settings(self, msg: ez.Settings) -> None: async def on_message(self, message: AxisArray) -> AsyncGenerator: """ Process incoming AxisArray messages and set their keys. - + Uses zero-copy operations to efficiently modify the key field while preserving all other data. - + :param message: Input AxisArray to modify. :type message: AxisArray :return: Async generator yielding AxisArray with modified key. @@ -109,10 +110,10 @@ async def on_message(self, message: AxisArray) -> AsyncGenerator: class FilterOnKey(ez.Unit): """ Filter an AxisArray based on its key. - + Only passes through AxisArray messages whose key matches the configured key setting. Uses zero-copy operations for efficient filtering. - + Note: There is no associated generator method for this Unit because messages that fail the filter would still be yielded (as None), which complicates downstream processing. For contexts where filtering on key is desired but the ezmsg framework is not used, use normal Python functional programming. @@ -129,10 +130,10 @@ class FilterOnKey(ez.Unit): async def on_message(self, message: AxisArray) -> AsyncGenerator: """ Filter incoming AxisArray messages based on their key. - + Only yields messages whose key matches the configured filter key. Uses minimal 'touch' to prevent unnecessary deep copying by the framework. - + :param message: Input AxisArray to filter. :type message: AxisArray :return: Async generator yielding filtered AxisArray messages. diff --git a/src/ezmsg/util/messages/modify.py b/src/ezmsg/util/messages/modify.py index 19f784b3..49a9bf9c 100644 --- a/src/ezmsg/util/messages/modify.py +++ b/src/ezmsg/util/messages/modify.py @@ -1,6 +1,5 @@ from collections.abc import AsyncGenerator, Generator import traceback -import typing import numpy as np @@ -60,23 +59,25 @@ def modify_axis( class ModifyAxisSettings(ez.Settings): """ Settings for ModifyAxis unit. 
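# Illustrative sketch (hypothetical wiring): tag messages from one source with a
# key, then let only matching messages through. Stream names INPUT_SIGNAL /
# OUTPUT_SIGNAL are assumed here -- they are not shown in this hunk.

from ezmsg.util.messages.key import SetKey, FilterOnKey, KeySettings

components = {
    "TAG": SetKey(settings=KeySettings(key="emg")),
    "KEEP_EMG": FilterOnKey(settings=KeySettings(key="emg")),
}
# connections = (
#     (components["TAG"].OUTPUT_SIGNAL, components["KEEP_EMG"].INPUT_SIGNAL),
# )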
- + Configuration for modifying axis names and dimensions of AxisArray messages. :param name_map: A dictionary where the keys are the names of the old dims and the values are the new names. Use None as a value to drop the dimension. If the dropped dimension is not len==1 then an error is raised. :type name_map: dict[str, str | None] | None """ + name_map: dict[str, str | None] | None = None class ModifyAxis(ez.Unit): """ Unit for modifying axis names and dimensions of AxisArray messages. - + Renames dimensions and axes according to a name mapping, with support for dropping dimensions. Uses zero-copy operations for efficient processing. """ + STATE = GenState SETTINGS = ModifyAxisSettings @@ -87,7 +88,7 @@ class ModifyAxis(ez.Unit): async def initialize(self) -> None: """ Initialize the ModifyAxis unit. - + Sets up the generator for axis modification operations. """ self.construct_generator() @@ -95,7 +96,7 @@ async def initialize(self) -> None: def construct_generator(self): """ Construct the axis-modifying generator with current settings. - + Creates a new modify_axis generator instance using the unit's name mapping. """ self.STATE.gen = modify_axis(name_map=self.SETTINGS.name_map) @@ -104,7 +105,7 @@ def construct_generator(self): async def on_settings(self, msg: ez.Settings) -> None: """ Handle incoming settings updates. - + :param msg: New settings to apply. :type msg: ez.Settings """ @@ -116,10 +117,10 @@ async def on_settings(self, msg: ez.Settings) -> None: async def on_message(self, message: AxisArray) -> AsyncGenerator: """ Process incoming AxisArray messages and modify their axes. - + Uses zero-copy operations to efficiently modify axis names and dimensions while preserving data integrity. - + :param message: Input AxisArray to modify. :type message: AxisArray :return: Async generator yielding AxisArray with modified axes. diff --git a/src/ezmsg/util/messages/util.py b/src/ezmsg/util/messages/util.py index 72094bde..a632f69e 100644 --- a/src/ezmsg/util/messages/util.py +++ b/src/ezmsg/util/messages/util.py @@ -9,7 +9,7 @@ def fast_replace(arr: typing.Generic[T], **kwargs) -> T: """ Fast replacement of dataclass fields with reduced safety. - + Unlike dataclasses.replace, this function does not check for type compatibility, nor does it check that the passed in fields are valid fields for the dataclass and not flagged as init=False. @@ -18,7 +18,7 @@ def fast_replace(arr: typing.Generic[T], **kwargs) -> T: To force ezmsg to use the legacy replace, set the environment variable: EZMSG_DISABLE_FAST_REPLACE Unset the variable to use this replace function. - + :param arr: The dataclass instance to create a modified copy of. :type arr: typing.Generic[T] :param kwargs: Field values to update in the new instance. 
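# Illustrative sketch (continuing the hypothetical `aa`): swap out just the payload
# of an AxisArray without the validation that dataclasses.replace performs, as
# described above.

from ezmsg.util.messages.util import fast_replace

louder = fast_replace(aa, data=aa.data * 2.0)
louder.dims == aa.dims    # -> True; fields not named in kwargs are carried over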
diff --git a/src/ezmsg/util/perf/__init__.py b/src/ezmsg/util/perf/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/ezmsg/util/perf/analysis.py b/src/ezmsg/util/perf/analysis.py new file mode 100644 index 00000000..590e681f --- /dev/null +++ b/src/ezmsg/util/perf/analysis.py @@ -0,0 +1,499 @@ +import json +import dataclasses +import argparse +import html +import math +import webbrowser + +from pathlib import Path + +from ..messagecodec import MessageDecoder +from .envinfo import TestEnvironmentInfo, format_env_diff +from .run import get_datestamp +from .impl import ( + TestParameters, + Metrics, + TestLogEntry, +) + +import ezmsg.core as ez + +try: + import xarray as xr + import pandas as pd # xarray depends on pandas +except ImportError: + ez.logger.error("ezmsg perf analysis requires xarray") + raise + +try: + import numpy as np +except ImportError: + ez.logger.error("ezmsg perf analysis requires numpy") + raise + +TEST_DESCRIPTION = """ +Configurations (config): +- fanin: many publishers to one subscriber +- fanout: one publisher to many subscribers +- relay: one publisher to one subscriber through many relays + +Communication strategies (comms): +- local: all subs, relays, and pubs are in the SAME process +- shm / tcp: some clients move to a second process; comms via shared memory / TCP + * fanin: all publishers moved + * fanout: all subscribers moved + * relay: the publisher and all relay nodes moved +- shm_spread / tcp_spread: each client in its own process; comms via SHM / TCP respectively + +Variables: +- n_clients: pubs (fanin), subs (fanout), or relays (relay) +- msg_size: nominal message size (bytes) + +Metrics: +- sample_rate: messages/sec at the sink (higher = better) +- data_rate: bytes/sec at the sink (higher = better) +- latency_mean: average send -> receive latency in seconds (lower = better) +""" + + +def load_perf(perf: Path) -> xr.Dataset: + all_results: dict[TestParameters, dict[int, list[Metrics]]] = dict() + + run_idx = 0 + + with open(perf, "r") as perf_f: + info: TestEnvironmentInfo = json.loads(next(perf_f), cls=MessageDecoder) + for line in perf_f: + obj = json.loads(line, cls=MessageDecoder) + if isinstance(obj, TestEnvironmentInfo): + run_idx += 1 + elif isinstance(obj, TestLogEntry): + runs = all_results.get(obj.params, dict()) + metrics = runs.get(run_idx, list()) + metrics.append(obj.results) + runs[run_idx] = metrics + all_results[obj.params] = runs + + n_clients_axis = list(sorted(set([p.n_clients for p in all_results.keys()]))) + msg_size_axis = list(sorted(set([p.msg_size for p in all_results.keys()]))) + comms_axis = list(sorted(set([p.comms for p in all_results.keys()]))) + config_axis = list(sorted(set([p.config for p in all_results.keys()]))) + + dims = ["n_clients", "msg_size", "comms", "config"] + coords = { + "n_clients": n_clients_axis, + "msg_size": msg_size_axis, + "comms": comms_axis, + "config": config_axis, + } + + data_vars = {} + for field in dataclasses.fields(Metrics): + m = ( + np.zeros( + ( + len(n_clients_axis), + len(msg_size_axis), + len(comms_axis), + len(config_axis), + ) + ) + * np.nan + ) + for p, a in all_results.items(): + # tests are run multiple times; get the median of means + m[ + n_clients_axis.index(p.n_clients), + msg_size_axis.index(p.msg_size), + comms_axis.index(p.comms), + config_axis.index(p.config), + ] = np.median( + [np.mean([getattr(v, field.name) for v in r]) for r in a.values()] + ) + data_vars[field.name] = xr.DataArray(m, dims=dims, coords=coords) + + dataset = 
xr.Dataset(data_vars, attrs=dict(info=info)) + return dataset + + +def _escape(s: str) -> str: + return html.escape(str(s), quote=True) + + +def _env_block(title: str, body: str) -> str: + return f""" +
+

{_escape(title)}

+
{_escape(body).strip()}
+
+ """ + + +def _legend_block() -> str: + return """ +
+

Legend

+ +
+ """ + + +def _base_css() -> str: + # Minimal, print-friendly CSS + color scales for cells. + return """ + + """ + + +def _color_for_comparison( + value: float, metric: str, noise_band_pct: float = 10.0 +) -> str: + """ + Returns inline CSS background for a comparison % value. + value: e.g., 97.3, 104.8, etc. + For sample_rate/data_rate: improvement > 100 (good). + For latency_mean: improvement < 100 (good). + Noise band ±10% around 100 is neutral. + """ + if not (isinstance(value, (int, float)) and math.isfinite(value)): + return "" + + delta = value - 100.0 + # Determine direction: + is good for sample/data; - is good for latency + if "rate" in metric: + # positive delta good, negative bad + magnitude = abs(delta) + sign_good = delta > 0 + elif "latency" in metric: + # negative delta good (lower latency) + magnitude = abs(delta) + sign_good = delta < 0 + else: + return "" + + # Noise band: keep neutral + if magnitude <= noise_band_pct: + return "" + + # Scale 5%..50% across 0..1; clamp + scale = max(0.0, min(1.0, (magnitude - noise_band_pct) / 45.0)) + + # Choose hue and lightness; use HSL with gentle saturation + hue = "var(--green)" if sign_good else "var(--red)" + # opacity via alpha blend on lightness via HSLa + # Use saturation ~70%, lightness around 40–50% blended with table bg + alpha = 0.15 + 0.35 * scale # 0.15..0.50 + return f"background-color: hsla({hue}, 70%, 45%, {alpha});" + + +def _format_number(x) -> str: + if isinstance(x, (int,)) and not isinstance(x, bool): + return f"{x:d}" + try: + xf = float(x) + except Exception: + return _escape(str(x)) + # Heuristic: for comparison percentages, 1 decimal is nice; for absolute, 3 decimals for latency. + return f"{xf:.3f}" + + +def summary(perf_path: Path, baseline_path: Path | None, html: bool = False) -> None: + """print perf test results and comparisons to the console""" + + output = "" + + perf = load_perf(perf_path) + info: TestEnvironmentInfo = perf.attrs["info"] + output += str(info) + "\n\n" + + relative = False + env_diff = None + if baseline_path is not None: + relative = True + output += "PERFORMANCE COMPARISON\n\n" + baseline = load_perf(baseline_path) + perf = (perf / baseline) * 100.0 + baseline_info: TestEnvironmentInfo = baseline.attrs["info"] + env_diff = format_env_diff(info.diff(baseline_info)) + output += env_diff + "\n\n" + + # These raw stats are still valuable to have, but are confusing + # when making relative comparisons + perf = perf.drop_vars(["latency_total", "num_msgs"]) + + perf = perf.stack(params=["n_clients", "msg_size"]).dropna("params") + df = perf.squeeze().to_dataframe() + df = df.drop("n_clients", axis=1) + df = df.drop("msg_size", axis=1) + + for _, config_ds in perf.groupby("config"): + for _, comms_ds in config_ds.groupby("comms"): + output += str(comms_ds.squeeze().to_dataframe()) + "\n\n" + output += "\n" + + print(output) + + if html: + # Ensure expected columns exist + expected_cols = { + "sample_rate_mean", + "sample_rate_median", + "data_rate", + "latency_mean", + "latency_median", + } + missing = expected_cols - set(df.columns) + if missing: + raise ValueError(f"Missing expected columns in dataset: {missing}") + + # We'll render a table per (config, comms) group. + groups = ( + df.reset_index() + .sort_values(by=["config", "comms", "n_clients", "msg_size"]) + .groupby(["config", "comms"], sort=False) + ) + + # Build HTML + parts: list[str] = [] + parts.append("") + parts.append( + "" + ) + parts.append("ezmsg perf report") + parts.append(_base_css()) + parts.append("
") + + parts.append("
") + parts.append("

ezmsg Performance Report

") + sub = str(perf_path) + if baseline_path is not None: + sub += f" relative to {str(baseline_path)}" + parts.append(f"
{_escape(sub)}
") + parts.append("
") + + if info is not None: + parts.append(_env_block("Test Environment", str(info))) + + parts.append(_env_block("Test Details", TEST_DESCRIPTION)) + + if env_diff is not None: + # Show diffs using your helper + parts.append("
") + parts.append("

Environment Differences vs Baseline

") + parts.append(f"
{_escape(env_diff)}
") + parts.append("
") + parts.append(_legend_block()) + + # Render each group + for (config, comms), g in groups: + # Keep only expected columns in order + cols = [ + "n_clients", + "msg_size", + "sample_rate_mean", + "sample_rate_median", + "data_rate", + "latency_mean", + "latency_median", + ] + g = g[cols].copy() + + # String format some columns (msg_size with separators) + g["msg_size"] = g["msg_size"].map( + lambda x: f"{int(x):,}" if pd.notna(x) else x + ) + + # Build table manually so we can inject inline cell styles easily + # (pandas Styler is great but produces bulky HTML; manual keeps it clean) + header = f""" + + + n_clients + msg_size {"" if relative else "(b)"} + sample_rate_mean {"" if relative else "(msgs/s)"} + sample_rate_median {"" if relative else "(msgs/s)"} + data_rate {"" if relative else "(MB/s)"} + latency_mean {"" if relative else "(us)"} + latency_median {"" if relative else "(us)"} + + + """ + body_rows: list[str] = [] + for _, row in g.iterrows(): + sr, srm, dr, lt, lm = ( + row["sample_rate_mean"], + row["sample_rate_median"], + row["data_rate"], + row["latency_mean"], + row["latency_median"], + ) + dr = dr if relative else dr / 2**20 + lt = lt if relative else lt * 1e6 + lm = lm if relative else lm * 1e6 + sr_style = ( + _color_for_comparison(sr, "sample_rate_mean") if relative else "" + ) + srm_style = ( + _color_for_comparison(srm, "sample_rate_median") if relative else "" + ) + dr_style = _color_for_comparison(dr, "data_rate") if relative else "" + lt_style = _color_for_comparison(lt, "latency_mean") if relative else "" + lm_style = ( + _color_for_comparison(lm, "latency_median") if relative else "" + ) + + body_rows.append( + "" + f"{_format_number(row['n_clients'])}" + f"{_escape(row['msg_size'])}" + f"{_format_number(sr)}" + f"{_format_number(srm)}" + f"{_format_number(dr)}" + f"{_format_number(lt)}" + f"{_format_number(lm)}" + "" + ) + table_html = f"{header}{''.join(body_rows)}
" + + parts.append( + f"

" + f"{_escape(config)}" + f"{_escape(comms)}" + f"

{table_html}
" + ) + + parts.append("
") + html_text = "".join(parts) + + out_path = Path(f"report_{get_datestamp()}.html") + out_path.write_text(html_text, encoding="utf-8") + webbrowser.open(out_path.resolve().as_uri()) + + +def setup_summary_cmdline(subparsers: argparse._SubParsersAction) -> None: + p_summary = subparsers.add_parser("summary", help="summarize performance results") + p_summary.add_argument( + "perf", + type=Path, + help="perf test", + ) + p_summary.add_argument( + "--baseline", + "-b", + type=Path, + default=None, + help="baseline perf test for comparison", + ) + p_summary.add_argument( + "--html", + action="store_true", + help="generate an html output file and render results in browser", + ) + + p_summary.set_defaults( + _handler=lambda ns: summary( + perf_path=ns.perf, baseline_path=ns.baseline, html=ns.html + ) + ) diff --git a/src/ezmsg/util/perf/command.py b/src/ezmsg/util/perf/command.py new file mode 100644 index 00000000..21fed7eb --- /dev/null +++ b/src/ezmsg/util/perf/command.py @@ -0,0 +1,19 @@ +import argparse + +from .analysis import setup_summary_cmdline +from .run import setup_run_cmdline + + +def command() -> None: + parser = argparse.ArgumentParser(description="ezmsg perf test utility") + subparsers = parser.add_subparsers(dest="command", required=True) + + setup_run_cmdline(subparsers) + setup_summary_cmdline(subparsers) + + ns = parser.parse_args() + ns._handler(ns) + + +if __name__ == "__main__": + command() diff --git a/src/ezmsg/util/perf/envinfo.py b/src/ezmsg/util/perf/envinfo.py new file mode 100644 index 00000000..654454f0 --- /dev/null +++ b/src/ezmsg/util/perf/envinfo.py @@ -0,0 +1,106 @@ +import dataclasses +import datetime +import platform +import typing +import sys +import subprocess + +import ezmsg.core as ez + +try: + import numpy as np +except ImportError: + ez.logger.error("ezmsg perf requires numpy") + raise + +try: + import psutil +except ImportError: + ez.logger.error("ezmsg perf requires psutil") + raise + + +def _git_commit() -> str: + try: + return ( + subprocess.check_output( + ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL + ) + .decode() + .strip() + ) + except (subprocess.CalledProcessError, FileNotFoundError): + return "unknown" + + +def _git_branch() -> str: + try: + return ( + subprocess.check_output( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], stderr=subprocess.DEVNULL + ) + .decode() + .strip() + ) + except (subprocess.CalledProcessError, FileNotFoundError): + return "unknown" + + +@dataclasses.dataclass +class TestEnvironmentInfo: + ezmsg_version: str = dataclasses.field(default_factory=lambda: ez.__version__) + numpy_version: str = dataclasses.field(default_factory=lambda: np.__version__) + python_version: str = dataclasses.field( + default_factory=lambda: sys.version.replace("\n", " ") + ) + os: str = dataclasses.field(default_factory=lambda: platform.system()) + os_version: str = dataclasses.field(default_factory=lambda: platform.version()) + machine: str = dataclasses.field(default_factory=lambda: platform.machine()) + processor: str = dataclasses.field(default_factory=lambda: platform.processor()) + cpu_count_logical: int | None = dataclasses.field( + default_factory=lambda: psutil.cpu_count(logical=True) + ) + cpu_count_physical: int | None = dataclasses.field( + default_factory=lambda: psutil.cpu_count(logical=False) + ) + memory_gb: float = dataclasses.field( + default_factory=lambda: round(psutil.virtual_memory().total / (1024**3), 2) + ) + start_time: str = dataclasses.field( + default_factory=lambda: 
datetime.datetime.now().isoformat(timespec="seconds") + ) + git_commit: str = dataclasses.field(default_factory=_git_commit) + git_branch: str = dataclasses.field(default_factory=_git_branch) + + def __str__(self) -> str: + fields = dataclasses.asdict(self) + width = max(len(k) for k in fields) + lines = ["TestEnvironmentInfo:"] + for key, value in fields.items(): + lines.append(f" {key.ljust(width)} : {value}") + return "\n".join(lines) + + def diff( + self, other: "TestEnvironmentInfo" + ) -> dict[str, tuple[typing.Any, typing.Any]]: + """Return a structured diff: {field: (self_value, other_value)} for changed fields.""" + a = dataclasses.asdict(self) + b = dataclasses.asdict(other) + keys = set(a) | set(b) + return {k: (a.get(k), b.get(k)) for k in keys if a.get(k) != b.get(k)} + + +def format_env_diff(diffs: dict[str, tuple[typing.Any, typing.Any]]) -> str: + """Pretty-print the structured diff in the same aligned style.""" + if not diffs: + return "No differences." + width = max(len(k) for k in diffs) + lines = ["Differences in TestEnvironmentInfo:"] + for k in sorted(diffs): + left, right = diffs[k] + lines.append(f" {k.ljust(width)} : {left} != {right}") + return "\n".join(lines) + + +def diff_envs(a: TestEnvironmentInfo, b: TestEnvironmentInfo) -> str: + return format_env_diff(a.diff(b)) diff --git a/src/ezmsg/util/perf/impl.py b/src/ezmsg/util/perf/impl.py new file mode 100644 index 00000000..0e8460c8 --- /dev/null +++ b/src/ezmsg/util/perf/impl.py @@ -0,0 +1,379 @@ +import asyncio +import dataclasses +import os +import time +import typing +import enum +import sys + +import ezmsg.core as ez + +from ezmsg.util.messages.util import replace +from ezmsg.core.netprotocol import Address + +try: + import numpy as np +except ImportError: + ez.logger.error("ezmsg perf requires numpy") + raise + +from .util import stable_perf + +try: + StrEnum = enum.StrEnum # type: ignore[attr-defined] +except AttributeError: + class StrEnum(str, enum.Enum): + """Fallback for Python < 3.11 where enum.StrEnum is unavailable.""" + + pass + +TIME = time.monotonic +if sys.platform.startswith('win'): + TIME = time.perf_counter + +def collect( + components: typing.Optional[typing.Mapping[str, ez.Component]] = None, + network: ez.NetworkDefinition = (), + process_components: typing.Collection[ez.Component] | None = None, + **components_kwargs: ez.Component, +) -> ez.Collection: + """collect a grouping of pre-configured components into a new "Collection" """ + from ezmsg.core.util import either_dict_or_kwargs + + components = either_dict_or_kwargs(components, components_kwargs, "collect") + if components is None: + raise ValueError("Must supply at least one component to run") + + out = ez.Collection() + for name, comp in components.items(): + comp._set_name(name) + out._components = ( + components # FIXME: Component._components should be typehinted as a Mapping + ) + out.network = lambda: network + out.process_components = ( + lambda: (out,) if process_components is None else process_components + ) + return out + + +@dataclasses.dataclass +class Metrics: + num_msgs: int + sample_rate_mean: float + sample_rate_median: float + latency_mean: float + latency_median: float + latency_total: float + data_rate: float + + +class LoadTestSettings(ez.Settings): + max_duration: float + num_msgs: int + dynamic_size: int + buffers: int + force_tcp: bool + + +@dataclasses.dataclass +class LoadTestSample: + _timestamp: float + counter: int + dynamic_data: np.ndarray + key: str + + +class LoadTestSourceState(ez.State): + counter: 
int = 0 + + +class LoadTestSource(ez.Unit): + OUTPUT = ez.OutputStream(LoadTestSample) + SETTINGS = LoadTestSettings + STATE = LoadTestSourceState + + async def initialize(self) -> None: + self.OUTPUT.num_buffers = self.SETTINGS.buffers + self.OUTPUT.force_tcp = self.SETTINGS.force_tcp + + @ez.publisher(OUTPUT) + async def publish(self) -> typing.AsyncGenerator: + ez.logger.info(f"Load test publisher started. (PID: {os.getpid()})") + start_time = TIME() + for _ in range(self.SETTINGS.num_msgs): + current_time = TIME() + if current_time - start_time >= self.SETTINGS.max_duration: + break + + yield ( + self.OUTPUT, + LoadTestSample( + _timestamp=TIME(), + counter=self.STATE.counter, + dynamic_data=np.zeros( + int(self.SETTINGS.dynamic_size // 4), dtype=np.float32 + ), + key=self.name, + ), + ) + self.STATE.counter += 1 + + ez.logger.info("Exiting publish") + raise ez.Complete + + async def shutdown(self) -> None: + ez.logger.info(f"Samples sent: {self.STATE.counter}") + + +class LoadTestRelay(ez.Unit): + INPUT = ez.InputStream(LoadTestSample) + OUTPUT = ez.OutputStream(LoadTestSample) + + @ez.subscriber(INPUT, zero_copy=True) + @ez.publisher(OUTPUT) + async def on_msg(self, msg: LoadTestSample) -> typing.AsyncGenerator: + yield self.OUTPUT, msg + + +class LoadTestReceiverState(ez.State): + # Tuples of sent timestamp, received timestamp, counter, dynamic size + received_data: list[tuple[float, float, int]] = dataclasses.field( + default_factory=list + ) + counters: dict[str, int] = dataclasses.field(default_factory=dict) + + +class LoadTestReceiver(ez.Unit): + INPUT = ez.InputStream(LoadTestSample) + SETTINGS = LoadTestSettings + STATE = LoadTestReceiverState + + async def initialize(self) -> None: + ez.logger.info(f"Load test subscriber started. (PID: {os.getpid()})") + + @ez.subscriber(INPUT, zero_copy=True) + async def receive(self, sample: LoadTestSample) -> None: + counter = self.STATE.counters.get(sample.key, -1) + if sample.counter != counter + 1: + ez.logger.warning(f"{sample.counter - counter - 1} samples skipped!") + self.STATE.received_data.append( + (sample._timestamp, TIME(), sample.counter) + ) + self.STATE.counters[sample.key] = sample.counter + + +class LoadTestSink(LoadTestReceiver): + INPUT = ez.InputStream(LoadTestSample) + + @ez.subscriber(INPUT, zero_copy=True) + async def receive(self, sample: LoadTestSample) -> None: + await super().receive(sample) + if len(self.STATE.received_data) == self.SETTINGS.num_msgs: + raise ez.NormalTermination + + @ez.task + async def terminate(self) -> None: + # Wait for the max duration of the load test + await asyncio.sleep(self.SETTINGS.max_duration) + ez.logger.warning("TIMEOUT -- terminating test.") + raise ez.NormalTermination + + +### TEST CONFIGURATIONS + + +@dataclasses.dataclass +class ConfigSettings: + n_clients: int + settings: LoadTestSettings + source: LoadTestSource + sink: LoadTestSink + + +Configuration = typing.Tuple[typing.Iterable[ez.Component], ez.NetworkDefinition] +Configurator = typing.Callable[[ConfigSettings], Configuration] + + +def fanout(config: ConfigSettings) -> Configuration: + """one pub to many subs""" + connections: ez.NetworkDefinition = [(config.source.OUTPUT, config.sink.INPUT)] + subs = [LoadTestReceiver(config.settings) for _ in range(config.n_clients)] + for sub in subs: + connections.append((config.source.OUTPUT, sub.INPUT)) + + return subs, connections + + +def fanin(config: ConfigSettings) -> Configuration: + """many pubs to one sub""" + connections: ez.NetworkDefinition = [(config.source.OUTPUT, 
config.sink.INPUT)] + pubs = [LoadTestSource(config.settings) for _ in range(config.n_clients)] + expected_num_msgs = config.sink.SETTINGS.num_msgs * len(pubs) + config.sink.SETTINGS = replace(config.sink.SETTINGS, num_msgs=expected_num_msgs) # type: ignore + for pub in pubs: + connections.append((pub.OUTPUT, config.sink.INPUT)) + return pubs, connections + + +def relay(config: ConfigSettings) -> Configuration: + """one pub to one sub through many relays""" + connections: ez.NetworkDefinition = [] + + relays = [LoadTestRelay(config.settings) for _ in range(config.n_clients)] + if len(relays): + connections.append((config.source.OUTPUT, relays[0].INPUT)) + for from_relay, to_relay in zip(relays[:-1], relays[1:]): + connections.append((from_relay.OUTPUT, to_relay.INPUT)) + connections.append((relays[-1].OUTPUT, config.sink.INPUT)) + else: + connections.append((config.source.OUTPUT, config.sink.INPUT)) + + return relays, connections + + +CONFIGS: typing.Mapping[str, Configurator] = { + c.__name__: c for c in [fanin, fanout, relay] +} + + +class Communication(StrEnum): + LOCAL = "local" + SHM = "shm" + SHM_SPREAD = "shm_spread" + TCP = "tcp" + TCP_SPREAD = "tcp_spread" + + +def perform_test( + n_clients: int, + max_duration: float, + num_msgs: int, + msg_size: int, + buffers: int, + comms: Communication, + config: Configurator, + graph_address: Address, +) -> Metrics: + settings = LoadTestSettings( + dynamic_size=int(msg_size), + num_msgs=num_msgs, + max_duration=max_duration, + buffers=buffers, + force_tcp=(comms in (Communication.TCP, Communication.TCP_SPREAD)), + ) + + source = LoadTestSource(settings) + sink = LoadTestSink(settings) + + components: typing.Mapping[str, ez.Component] = dict( + SINK=sink, + ) + + clients, connections = config(ConfigSettings(n_clients, settings, source, sink)) + + # The 'sink' MUST remain in this process for us to pull its state. + process_components: typing.Iterable[ez.Component] = [] + if comms == Communication.LOCAL: + # Every component in the same process (this one) + components["SOURCE"] = source + for i, client in enumerate(clients): + components[f"CLIENT_{i + 1}"] = client + + else: + if comms in (Communication.SHM_SPREAD, Communication.TCP_SPREAD): + # Every component in its own process. + components["SOURCE"] = source + process_components.append(source) + for i, client in enumerate(clients): + components[f"CLIENT_{i + 1}"] = client + process_components.append(client) + + else: + # All clients and the source in ONE other process. 
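+            # collect() (defined near the top of this module) bundles the source
+            # and all clients into a single ez.Collection; handing that Collection
+            # to ez.run via process_components makes the whole group run in one
+            # additional process, while the sink stays in the main process.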
+ collect_comps: typing.Mapping[str, ez.Component] = dict() + collect_comps["SOURCE"] = source + for i, client in enumerate(clients): + collect_comps[f"CLIENT_{i + 1}"] = client + proc_collection = collect(components=collect_comps) + components["PROC"] = proc_collection + process_components = [proc_collection] + + with stable_perf(): + ez.run( + components=components, + connections=connections, + process_components=process_components, + graph_address=graph_address, + ) + + return calculate_metrics(sink) + + +def calculate_metrics(sink: LoadTestSink) -> Metrics: + # Log some useful summary statistics + min_timestamp = min(timestamp for timestamp, _, _ in sink.STATE.received_data) + max_timestamp = max(timestamp for timestamp, _, _ in sink.STATE.received_data) + latency = [ + receive_timestamp - send_timestamp + for send_timestamp, receive_timestamp, _ in sink.STATE.received_data + ] + total_latency = abs(sum(latency)) + + counters = list(sorted(t[2] for t in sink.STATE.received_data)) + dropped_samples = sum( + [max((x1 - x0) - 1, 0) for x1, x0 in zip(counters[1:], counters[:-1])] + ) + + rx_timestamps = np.array([rx_ts for _, rx_ts, _ in sink.STATE.received_data]) + rx_timestamps.sort() + runtime = max_timestamp - min_timestamp + num_samples = len(sink.STATE.received_data) + samplerate_mean = num_samples / runtime + diff_timestamps = np.diff(rx_timestamps) + diff_timestamps = diff_timestamps[np.nonzero(diff_timestamps)] + samplerate_median = 1.0 / float(np.median(diff_timestamps)) + latency_mean = total_latency / num_samples + latency_median = list(sorted(latency))[len(latency) // 2] + total_data = num_samples * sink.SETTINGS.dynamic_size + data_rate = total_data / runtime + + ez.logger.info(f"Samples received: {num_samples}") + ez.logger.info(f"Mean sample rate: {samplerate_mean} Hz") + ez.logger.info(f"Median sample rate: {samplerate_median} Hz") + ez.logger.info(f"Mean latency: {latency_mean} s") + ez.logger.info(f"Median latency: {latency_median} s") + ez.logger.info(f"Total latency: {total_latency} s") + ez.logger.info(f"Data rate: {data_rate * 1e-6} MB/s") + + if dropped_samples: + ez.logger.error( + f"Dropped samples: {dropped_samples} ({dropped_samples / (dropped_samples + num_samples)}%)", + ) + + return Metrics( + num_msgs=num_samples, + sample_rate_mean=samplerate_mean, + sample_rate_median=samplerate_median, + latency_mean=latency_mean, + latency_median=latency_median, + latency_total=total_latency, + data_rate=data_rate, + ) + + +@dataclasses.dataclass(unsafe_hash=True) +class TestParameters: + msg_size: int + num_msgs: int + n_clients: int + config: str + comms: str + max_duration: float + num_buffers: int + + +@dataclasses.dataclass +class TestLogEntry: + params: TestParameters + results: Metrics diff --git a/src/ezmsg/util/perf/run.py b/src/ezmsg/util/perf/run.py new file mode 100644 index 00000000..3f794f21 --- /dev/null +++ b/src/ezmsg/util/perf/run.py @@ -0,0 +1,345 @@ +import os +import sys +import json +import itertools +import argparse +import typing +import random +import time + +from datetime import datetime, timedelta +from contextlib import contextmanager, redirect_stdout, redirect_stderr + +import ezmsg.core as ez +from ezmsg.core.graphserver import GraphServer + +from ..messagecodec import MessageEncoder +from .envinfo import TestEnvironmentInfo +from .util import warmup +from .impl import ( + TestParameters, + TestLogEntry, + perform_test, + Communication, + CONFIGS, +) + +DEFAULT_MSG_SIZES = [2**4, 2**20] +DEFAULT_N_CLIENTS = [1, 16] +DEFAULT_COMMS = [c for c 
in Communication] + + +# --- Output Suppression Context Manager --- +@contextmanager +def suppress_output(verbose: bool = False): + """Context manager to redirect stdout and stderr to os.devnull""" + if verbose: + yield + else: + # Open the null device for writing + with open(os.devnull, "w") as fnull: + # Redirect both stdout and stderr to the null device + with redirect_stderr(fnull): + with redirect_stdout(fnull): + yield + + +def _check_for_quit_default() -> bool: + return False + + +CHECK_FOR_QUIT = _check_for_quit_default + +if sys.platform.startswith("win"): + import msvcrt + + def _check_for_quit_win() -> bool: + """ + Checks for the 'q' key press in a non-blocking way. + Returns True if 'q' is pressed (case-insensitive), False otherwise. + """ + # Windows: Use msvcrt for non-blocking keyboard hit detection + if msvcrt.kbhit(): # type: ignore + # Read the key press (returns bytes) + key = msvcrt.getch() # type: ignore + try: + # Decode and check for 'q' + return key.decode().lower() == "q" + except UnicodeDecodeError: + # Handle potential non-text key presses gracefully + return False + return False + + CHECK_FOR_QUIT = _check_for_quit_win + +else: + import select + + def _check_for_quit() -> bool: + """ + Checks for the 'q' key press in a non-blocking way. + Returns True if 'q' is pressed (case-insensitive), False otherwise. + """ + # Linux/macOS: Use select to check if stdin has data + # select.select(rlist, wlist, xlist, timeout) + # timeout=0 makes it non-blocking + if sys.stdin.isatty(): + i, o, e = select.select([sys.stdin], [], [], 0) # type: ignore + if i: + # Read the buffered character + key = sys.stdin.read(1) + return key.lower() == "q" + return False + + CHECK_FOR_QUIT = _check_for_quit + + +def get_datestamp() -> str: + return datetime.now().strftime("%Y%m%d_%H%M%S") + + +def perf_run( + max_duration: float, + num_msgs: int, + num_buffers: int, + iters: int, + repeats: int, + msg_sizes: list[int] | None, + n_clients: list[int] | None, + comms: typing.Iterable[str] | None, + configs: typing.Iterable[str] | None, + grid: bool, + warmup_dur: float, +) -> None: + if n_clients is None: + n_clients = DEFAULT_N_CLIENTS + if any(c < 0 for c in n_clients): + ez.logger.error("All tests must have >=0 clients") + return + + if msg_sizes is None: + msg_sizes = DEFAULT_MSG_SIZES + if any(s < 0 for s in msg_sizes): + ez.logger.error("All msg_sizes must be >=0 bytes") + + if not grid and len(list(n_clients)) != len(list(msg_sizes)): + ez.logger.warning( + "Not performing a grid test of all combinations of n_clients and msg_sizes, but " + + f"{len(n_clients)=} which is not equal to {len(msg_sizes)=}. " + ) + + try: + communications = ( + DEFAULT_COMMS if comms is None else [Communication(c) for c in comms] + ) + except ValueError: + ez.logger.error( + f"Invalid test communications requested. Valid communications: {', '.join([c.value for c in Communication])}" + ) + return + + try: + configurators = ( + list(CONFIGS.values()) if configs is None else [CONFIGS[c] for c in configs] + ) + except ValueError: + ez.logger.error( + f"Invalid test configuration requested. 
Valid configurations: {', '.join([c for c in CONFIGS])}" + ) + return + + subitr = itertools.product if grid else zip + + test_list = [ + (msg_size, clients, conf, comm) + for msg_size, clients in subitr(msg_sizes, n_clients) + for conf, comm in itertools.product(configurators, communications) + ] * iters + + random.shuffle(test_list) + + server = GraphServer() + server.start() + + ez.logger.info( + f"About to run {len(test_list)} tests (repeated {repeats} times) of {max_duration} sec (max) each." + ) + ez.logger.info( + f"During each test, source will attempt to send {num_msgs} messages to the sink." + ) + ez.logger.info( + "Please try to avoid running other taxing software while this perf test runs." + ) + ez.logger.info( + "NOTE: Tests swallow interrupt. After warmup, use 'q' then [enter] to quit tests early." + ) + + quitting = False + + start_time = time.time() + + try: + ez.logger.info(f"Warming up for {warmup_dur} seconds...") + warmup(warmup_dur) + + with open(f"perf_{get_datestamp()}.txt", "w") as out_f: + for _ in range(repeats): + out_f.write( + json.dumps(TestEnvironmentInfo(), cls=MessageEncoder) + "\n" + ) + + for test_idx, (msg_size, clients, conf, comm) in enumerate(test_list): + if CHECK_FOR_QUIT(): + ez.logger.info("Stopping tests early...") + quitting = True + break + + ez.logger.info( + f"TEST {test_idx + 1}/{len(test_list)}: " + f"{clients=}, {msg_size=}, conf={conf.__name__}, " + f"comm={comm.value}" + ) + + output = TestLogEntry( + params=TestParameters( + msg_size=msg_size, + num_msgs=num_msgs, + n_clients=clients, + config=conf.__name__, + comms=comm.value, + max_duration=max_duration, + num_buffers=num_buffers, + ), + results=perform_test( + n_clients=clients, + max_duration=max_duration, + num_msgs=num_msgs, + msg_size=msg_size, + buffers=num_buffers, + comms=comm, + config=conf, + graph_address=server.address, + ), + ) + + out_f.write(json.dumps(output, cls=MessageEncoder) + "\n") + + if quitting: + break + + finally: + server.stop() + d = datetime(1, 1, 1) + timedelta(seconds=time.time() - start_time) + dur_str = ":".join( + [str(n) for n in [d.day - 1, d.hour, d.minute, d.second] if n != 0] + ) + ez.logger.info(f"Tests concluded. Wallclock Runtime: {dur_str}s") + + +def setup_run_cmdline(subparsers: argparse._SubParsersAction) -> None: + p_run = subparsers.add_parser("run", help="run performance test") + + p_run.add_argument( + "--max-duration", + type=float, + default=5.0, + help="maximum individual test duration in seconds (default = 5.0)", + ) + + p_run.add_argument( + "--num-msgs", + type=int, + default=1000, + help="number of messages to send per-test (default = 1000)", + ) + + # NOTE: We default num-buffers = 1 because this degenerate perf test scenario (blasting + # messages as fast as possible through the system) results in one of two scenerios: + # 1. A (few) messages is/are enqueued and dequeued before another message is posted + # 2. The buffer fills up before being FULLY emptied resulting in longer latency. + # (once a channel enters this condition, it tends to stay in this condition) + # + # This _indeterminate_ behavior results in bimodal distributions of runtimes that make + # A/B performance comparisons difficult. The perf test is not representative of the vast + # majority of production ezmsg systems where publishing is generally rate-limited. 
+ # + # A flow-control algorithm could stabilize perf-test results with num_buffers > 1, but is + # generally implemented by enforcing delays on the publish side which simply degrades + # performance in the vast majority of ezmsg systems. - Griff + p_run.add_argument( + "--num-buffers", + type=int, + default=1, + help="shared memory buffers (default = 1)", + ) + + p_run.add_argument( + "--iters", + "-i", + type=int, + default=5, + help="number of times to run each test (default = 5)", + ) + + p_run.add_argument( + "--repeats", + "-r", + type=int, + default=10, + help="number of times to repeat the perf (default = 10)", + ) + + p_run.add_argument( + "--msg-sizes", + type=int, + default=None, + nargs="*", + help=f"message sizes in bytes (default = {DEFAULT_MSG_SIZES})", + ) + + p_run.add_argument( + "--n-clients", + type=int, + default=None, + nargs="*", + help=f"number of clients (default = {DEFAULT_N_CLIENTS})", + ) + + p_run.add_argument( + "--comms", + type=str, + default=None, + nargs="*", + help=f"communication strategies to test (default = {[c.value for c in DEFAULT_COMMS]})", + ) + + p_run.add_argument( + "--configs", + type=str, + default=None, + nargs="*", + help=f"configurations to test (default = {[c for c in CONFIGS]})", + ) + + p_run.add_argument( + "--warmup", + type=float, + default=60.0, + help="warmup CPU with busy task for some number of seconds (default = 60.0)", + ) + + p_run.set_defaults( + _handler=lambda ns: perf_run( + max_duration=ns.max_duration, + num_msgs=ns.num_msgs, + num_buffers=ns.num_buffers, + iters=ns.iters, + repeats=ns.repeats, + msg_sizes=ns.msg_sizes, + n_clients=ns.n_clients, + comms=ns.comms, + configs=ns.configs, + grid=True, + warmup_dur=ns.warmup, + ) + ) diff --git a/src/ezmsg/util/perf/util.py b/src/ezmsg/util/perf/util.py new file mode 100644 index 00000000..eb79562f --- /dev/null +++ b/src/ezmsg/util/perf/util.py @@ -0,0 +1,278 @@ +import os +import sys +import gc +import time +import statistics as stats +import contextlib +import subprocess +from dataclasses import dataclass +from typing import Iterable + +try: + import psutil # optional but helpful +except Exception: + psutil = None + +_IS_WIN = os.name == "nt" +_IS_MAC = sys.platform == "darwin" +_IS_LINUX = sys.platform.startswith("linux") + +# ---------- Utilities ---------- + + +def _set_env_threads(single_thread: bool = True): + """ + Normalize math/threading libs so they don't spawn surprise worker threads. + """ + if single_thread: + os.environ.setdefault("OMP_NUM_THREADS", "1") + os.environ.setdefault("MKL_NUM_THREADS", "1") + os.environ.setdefault("VECLIB_MAXIMUM_THREADS", "1") + os.environ.setdefault("OPENBLAS_NUM_THREADS", "1") + os.environ.setdefault("NUMEXPR_NUM_THREADS", "1") + # Keep PYTHONHASHSEED stable for deterministic dict/set iteration costs + os.environ.setdefault("PYTHONHASHSEED", "0") + + +# ---------- Priority & Affinity ---------- + + +@contextlib.contextmanager +def _process_priority(): + """ + Elevate process priority in a cross-platform best-effort way. 
+ """ + if psutil is None: + yield + return + + p = psutil.Process() + orig_nice = None + if _IS_WIN: + try: + import ctypes + + kernel32 = ctypes.WinDLL("kernel32", use_last_error=True) + ABOVE_NORMAL_PRIORITY_CLASS = 0x00008000 + HIGH_PRIORITY_CLASS = 0x00000080 + # Try High, fall back to Above Normal + if not kernel32.SetPriorityClass( + kernel32.GetCurrentProcess(), HIGH_PRIORITY_CLASS + ): + kernel32.SetPriorityClass( + kernel32.GetCurrentProcess(), ABOVE_NORMAL_PRIORITY_CLASS + ) + except Exception: + pass + else: + try: + orig_nice = p.nice() + # Negative nice may need privileges; try smaller magnitude first + for nice_val in (-10, -5, 0): + try: + p.nice(nice_val) + break + except Exception: + continue + except Exception: + pass + try: + yield + finally: + # restore nice if we changed it + if psutil is not None and not _IS_WIN and orig_nice is not None: + try: + p.nice(orig_nice) + except Exception: + pass + + +@contextlib.contextmanager +def _cpu_affinity(prefer_isolation: bool = True): + """ + Set CPU affinity to a small, stable set of CPUs (where supported). + macOS does not support affinity via psutil; we no-op there. + """ + if psutil is None or _IS_MAC: + yield + return + + p = psutil.Process() + original = None + try: + if hasattr(p, "cpu_affinity"): + original = p.cpu_affinity() + cpus = original + if prefer_isolation and len(cpus) > 2: + # Pick two middle CPUs to avoid 0 which often handles interrupts + mid = len(cpus) // 2 + cpus = [cpus[mid - 1], cpus[mid]] + p.cpu_affinity(cpus) + yield + finally: + try: + if original is not None and hasattr(p, "cpu_affinity"): + p.cpu_affinity(original) + except Exception: + pass + + +# ---------- Platform-specific helpers ---------- + + +@contextlib.contextmanager +def _mac_caffeinate(): + """ + Keep macOS awake during the run via a background caffeinate process. + """ + if not _IS_MAC: + yield + return + proc = None + try: + proc = subprocess.Popen(["caffeinate", "-dimsu"]) + except Exception: + proc = None + try: + yield + finally: + if proc is not None: + try: + proc.terminate() + except Exception: + pass + + +@contextlib.contextmanager +def _win_timer_resolution(ms: int = 1): + """ + On Windows, request a finer system timer to stabilize sleeps and scheduling slices. + """ + if not _IS_WIN: + yield + return + import ctypes + + winmm = ctypes.WinDLL("winmm") + timeBeginPeriod = winmm.timeBeginPeriod + timeEndPeriod = winmm.timeEndPeriod + try: + timeBeginPeriod(ms) + except Exception: + pass + try: + yield + finally: + try: + timeEndPeriod(ms) + except Exception: + pass + + +# ---------- Warm-up & GC ---------- + + +def warmup(seconds: float = 60.0, fn=None, *args, **kwargs): + """ + Optional warm-up to reach steady clocks/caches. + If fn is provided, call it in a loop for the given time. + """ + if seconds <= 0: + return + end = time.perf_counter() + target = end + seconds + if fn is None: + # Busy wait / sleep mix to heat up without heavy CPU + while time.perf_counter() < target: + x = 0 + for _ in range(10000): + x += 1 + time.sleep(0) + else: + while time.perf_counter() < target: + fn(*args, **kwargs) + + +@contextlib.contextmanager +def gc_pause(): + """ + Disable GC inside timing windows; re-enable and collect after. 
+ """ + was_enabled = gc.isenabled() + try: + gc.disable() + yield + finally: + if was_enabled: + gc.enable() + gc.collect() + + +# ---------- Robust statistics ---------- + + +def median_of_means(samples: Iterable[float], k: int = 5) -> float: + """ + Robust estimate: split samples into k buckets (round-robin), average each, take median of bucket means. + """ + samples = list(samples) + if not samples: + return float("nan") + k = max(1, min(k, len(samples))) + buckets = [[] for _ in range(k)] + for i, v in enumerate(samples): + buckets[i % k].append(v) + means = [sum(b) / len(b) for b in buckets if b] + means.sort() + return means[len(means) // 2] + + +def coef_var(samples: Iterable[float]) -> float: + vals = list(samples) + if len(vals) < 2: + return 0.0 + m = sum(vals) / len(vals) + if m == 0: + return 0.0 + sd = stats.pstdev(vals) + return sd / m + + +# ---------- Public context manager ---------- + + +@dataclass +class PerfOptions: + single_thread_math: bool = True + prefer_isolated_cpus: bool = True + warmup_seconds: float = 0.0 + adjust_priority: bool = True + tweak_timer_windows: bool = True + keep_mac_awake: bool = True + + +@contextlib.contextmanager +def stable_perf(opts: PerfOptions = PerfOptions()): + """ + Wrap your perf runs with this context manager for a stabler environment. + """ + _set_env_threads(opts.single_thread_math) + + cm_stack = contextlib.ExitStack() + try: + if opts.adjust_priority: + cm_stack.enter_context(_process_priority()) + if opts.tweak_timer_windows: + cm_stack.enter_context(_win_timer_resolution(1)) + if opts.prefer_isolated_cpus: + cm_stack.enter_context(_cpu_affinity(True)) + if opts.keep_mac_awake: + cm_stack.enter_context(_mac_caffeinate()) + + if opts.warmup_seconds > 0: + warmup(opts.warmup_seconds) + + with gc_pause(): + yield + finally: + cm_stack.close() diff --git a/src/ezmsg/util/perf_test.py b/src/ezmsg/util/perf_test.py deleted file mode 100644 index b0a5f66f..00000000 --- a/src/ezmsg/util/perf_test.py +++ /dev/null @@ -1,220 +0,0 @@ -import asyncio -import dataclasses -import datetime -import os -import platform -import time - -from collections.abc import AsyncGenerator - -import ezmsg.core as ez - -# We expect this test to generate LOTS of backpressure warnings -# PERF_LOGLEVEL = os.environ.get("EZMSG_LOGLEVEL", "ERROR") -# ez.logger.setLevel(PERF_LOGLEVEL) - -PLATFORM = { - "Darwin": "mac", - "Linux": "linux", - "Windows": "win", -}[platform.system()] -SAMPLE_SUMMARY_DATASET_PREFIX = "sample_summary" -COUNT_DATASET_NAME = "count" - - -try: - import numpy as np -except ImportError: - ez.logger.error("This test requires Numpy to run.") - raise - - -def get_datestamp() -> str: - return datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - - -class LoadTestSettings(ez.Settings): - duration: float = 30.0 - dynamic_size: int = 8 - buffers: int = 32 - - -@dataclasses.dataclass -class LoadTestSample: - _timestamp: float - counter: int - dynamic_data: np.ndarray - - -class LoadTestPublisher(ez.Unit): - OUTPUT = ez.OutputStream(LoadTestSample) - SETTINGS = LoadTestSettings - - async def initialize(self) -> None: - self.running = True - self.counter = 0 - self.OUTPUT.num_buffers = self.SETTINGS.buffers - - @ez.publisher(OUTPUT) - async def publish(self) -> AsyncGenerator: - ez.logger.info(f"Load test publisher started. 
(PID: {os.getpid()})") - start_time = time.time() - while self.running: - current_time = time.time() - if current_time - start_time >= self.SETTINGS.duration: - break - - yield ( - self.OUTPUT, - LoadTestSample( - _timestamp=time.time(), - counter=self.counter, - dynamic_data=np.zeros( - int(self.SETTINGS.dynamic_size // 8), dtype=np.float32 - ), - ), - ) - self.counter += 1 - ez.logger.info("Exiting publish") - raise ez.Complete - - async def shutdown(self) -> None: - self.running = False - ez.logger.info(f"Samples sent: {self.counter}") - - -class LoadTestSubscriberState(ez.State): - # Tuples of sent timestamp, received timestamp, counter, dynamic size - received_data: list[tuple[float, float, int]] = dataclasses.field( - default_factory=list - ) - counter: int = -1 - - -class LoadTestSubscriber(ez.Unit): - INPUT = ez.InputStream(LoadTestSample) - SETTINGS = LoadTestSettings - STATE = LoadTestSubscriberState - - @ez.subscriber(INPUT, zero_copy=True) - async def receive(self, sample: LoadTestSample) -> None: - if sample.counter != self.STATE.counter + 1: - ez.logger.warning( - f"{sample.counter - self.STATE.counter - 1} samples skipped!" - ) - self.STATE.received_data.append( - (sample._timestamp, time.time(), sample.counter) - ) - self.STATE.counter = sample.counter - - @ez.task - async def log_result(self) -> None: - ez.logger.info(f"Load test subscriber started. (PID: {os.getpid()})") - - # Wait for the duration of the load test - await asyncio.sleep(self.SETTINGS.duration) - # logger.info(f"STATE = {self.STATE.received_data}") - - # Log some useful summary statistics - min_timestamp = min(timestamp for timestamp, _, _ in self.STATE.received_data) - max_timestamp = max(timestamp for timestamp, _, _ in self.STATE.received_data) - total_latency = abs( - sum( - receive_timestamp - send_timestamp - for send_timestamp, receive_timestamp, _ in self.STATE.received_data - ) - ) - - counters = list(sorted(t[2] for t in self.STATE.received_data)) - dropped_samples = sum( - [(x1 - x0) - 1 for x1, x0 in zip(counters[1:], counters[:-1])] - ) - - num_samples = len(self.STATE.received_data) - ez.logger.info(f"Samples received: {num_samples}") - ez.logger.info( - f"Sample rate: {num_samples / (max_timestamp - min_timestamp)} Hz" - ) - ez.logger.info(f"Mean latency: {total_latency / num_samples} s") - ez.logger.info(f"Total latency: {total_latency} s") - - total_data = num_samples * self.SETTINGS.dynamic_size - ez.logger.info( - f"Data rate: {total_data / (max_timestamp - min_timestamp) * 1e-6} MB/s" - ) - ez.logger.info( - f"Dropped samples: {dropped_samples} ({dropped_samples / (dropped_samples + num_samples)}%)", - ) - - raise ez.NormalTermination - - -class LoadTest(ez.Collection): - SETTINGS = LoadTestSettings - - PUBLISHER = LoadTestPublisher() - SUBSCRIBER = LoadTestSubscriber() - - def configure(self) -> None: - self.PUBLISHER.apply_settings(self.SETTINGS) - self.SUBSCRIBER.apply_settings(self.SETTINGS) - - def network(self) -> ez.NetworkDefinition: - return ((self.PUBLISHER.OUTPUT, self.SUBSCRIBER.INPUT),) - - def process_components(self): - return ( - self.PUBLISHER, - self.SUBSCRIBER, - ) - - -def get_time() -> float: - # time.perf_counter() isn't system-wide on Windows Python 3.6: - # https://bugs.python.org/issue37205 - return time.time() if PLATFORM == "win" else time.perf_counter() - - -def test_performance(duration, size, buffers) -> None: - ez.logger.info(f"Running load test for dynamic size: {size} bytes") - system = LoadTest( - LoadTestSettings(dynamic_size=int(size), 
duration=duration, buffers=buffers) - ) - ez.run(SYSTEM=system) - - -def run_many_dynamic_sizes(duration, buffers) -> None: - for exp in range(5, 22, 4): - test_performance(duration, 2**exp, buffers) - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument( - "--many-dynamic-sizes", - action="store_true", - help="Run load test for many dynamic sizes", - ) - parser.add_argument( - "--duration", - type=int, - default=2, - help="How long to run the load test (seconds)", - ) - parser.add_argument( - "--num-buffers", type=int, default=32, help="Shared memory buffers" - ) - - class Args: - many_dynamic_sizes: bool - duration: int - num_buffers: int - - args = parser.parse_args(namespace=Args) - - if args.many_dynamic_sizes: - run_many_dynamic_sizes(args.duration, args.num_buffers) - else: - test_performance(args.duration, 8, args.num_buffers) diff --git a/src/ezmsg/util/rate.py b/src/ezmsg/util/rate.py index 91452eba..e8e51afb 100644 --- a/src/ezmsg/util/rate.py +++ b/src/ezmsg/util/rate.py @@ -30,7 +30,7 @@ def __init__(self, hz: float): def _remaining(self, curr_time: float) -> float: """ Calculate the time remaining for rate to sleep. - + :param curr_time: Current time :type curr_time: float :return: Time remaining in seconds @@ -47,7 +47,7 @@ def _remaining(self, curr_time: float) -> float: def remaining(self) -> float: """ Return the time remaining for rate to sleep. - + :return: Time remaining in seconds :rtype: float """ @@ -57,7 +57,7 @@ def remaining(self) -> float: def _sleep_logic(self) -> float: """ Internal method to calculate sleep time and update timing state. - + :return: Time to sleep in seconds (non-negative) :rtype: float """ diff --git a/tests/ez_test_utils.py b/tests/ez_test_utils.py index a27c790e..6ac7134f 100644 --- a/tests/ez_test_utils.py +++ b/tests/ez_test_utils.py @@ -1,5 +1,7 @@ from dataclasses import asdict, dataclass from collections.abc import AsyncGenerator +from contextlib import contextmanager + import json import os from pathlib import Path @@ -9,7 +11,8 @@ import ezmsg.core as ez -def get_test_fn(test_name: str | None = None, extension: str = "txt") -> Path: +@contextmanager +def get_test_fn(test_name: str | None = None, extension: str = "txt") -> typing.Generator[Path, None, None]: """PYTEST compatible temporary test file creator""" # Get current test name if we can.. @@ -20,14 +23,20 @@ def get_test_fn(test_name: str | None = None, extension: str = "txt") -> Path: else: test_name = __name__ - file_path = Path(tempfile.gettempdir()) - file_path = file_path / Path(f"{test_name}.{extension}") - - # Create the file - with open(file_path, "w"): - pass - - return file_path + # Create a unique temporary file name to avoid collisions when running the + # full test suite in parallel or when other tests use the same test name. + # Use NamedTemporaryFile with delete=False so callers can open/remove it. 
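+    # With the default delete=True, closing the file below also removes it; only
+    # the unique name is yielded, callers can recreate the file at that path, and
+    # the finally block unlinks whatever is left afterwards.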
+ prefix = f"{test_name}-" if test_name else "test-" + tmp = tempfile.NamedTemporaryFile(prefix=prefix, suffix=f".{extension}") + tmp.close() # Close so others can open it on Windows + path = Path(tmp.name) + try: + yield path + finally: + try: + path.unlink() + except FileNotFoundError: + pass # MESSAGE DEFINITIONS diff --git a/tests/messages/test_axisarray.py b/tests/messages/test_axisarray.py index 6dfa79b0..e11bc474 100644 --- a/tests/messages/test_axisarray.py +++ b/tests/messages/test_axisarray.py @@ -1,3 +1,4 @@ +import importlib.util import pytest import numpy as np @@ -274,15 +275,8 @@ def test_sliding_win_oneaxis(nwin: int, axis: int, step: int): assert np.shares_memory(res, expected) -def xarray_available(): - try: - import xarray - - return True - except ImportError: - return False - except ValueError: - return False +def xarray_available() -> bool: + return importlib.util.find_spec("xarray") is not None @pytest.mark.skipif( diff --git a/tests/messages/test_key.py b/tests/messages/test_key.py index d5222680..aa44ee60 100644 --- a/tests/messages/test_key.py +++ b/tests/messages/test_key.py @@ -2,7 +2,6 @@ import copy import json import os -import typing import numpy as np @@ -75,71 +74,69 @@ async def on_message(self, msg: AxisArray) -> None: def test_set_key_unit(): num_msgs = 20 - test_fn_raw = get_test_fn() - test_fn_keyed = test_fn_raw.parent / ( - test_fn_raw.stem + "_keyed" + test_fn_raw.suffix - ) - - comps = { - "SRC": KeyedAxarrGenerator(num_msgs=num_msgs), - "KEYSETTER": SetKey(key="new key"), - "SINK_RAW": AxarrReceiver(num_msgs=1e9, output_fn=test_fn_raw), - "SINK_KEYED": AxarrReceiver(num_msgs=num_msgs, output_fn=test_fn_keyed), - } - conns = [ - (comps["SRC"].OUTPUT_SIGNAL, comps["SINK_RAW"].INPUT_SIGNAL), - (comps["SRC"].OUTPUT_SIGNAL, comps["KEYSETTER"].INPUT_SIGNAL), - (comps["KEYSETTER"].OUTPUT_SIGNAL, comps["SINK_KEYED"].INPUT_SIGNAL), - ] - ez.run(components=comps, connections=conns) - - with open(test_fn_raw, "r") as file: - raw_results = [json.loads(_) for _ in file.readlines()] - os.remove(test_fn_raw) - assert len(raw_results) == num_msgs - assert all( - [ - _[str(ix + 1)] == ("odd" if ix % 2 else "even") - for ix, _ in enumerate(raw_results) + with get_test_fn() as test_fn_raw: + test_fn_keyed = test_fn_raw.parent / ( + test_fn_raw.stem + "_keyed" + test_fn_raw.suffix + ) + + comps = { + "SRC": KeyedAxarrGenerator(num_msgs=num_msgs), + "KEYSETTER": SetKey(key="new key"), + "SINK_RAW": AxarrReceiver(num_msgs=1e9, output_fn=test_fn_raw), + "SINK_KEYED": AxarrReceiver(num_msgs=num_msgs, output_fn=test_fn_keyed), + } + conns = [ + (comps["SRC"].OUTPUT_SIGNAL, comps["SINK_RAW"].INPUT_SIGNAL), + (comps["SRC"].OUTPUT_SIGNAL, comps["KEYSETTER"].INPUT_SIGNAL), + (comps["KEYSETTER"].OUTPUT_SIGNAL, comps["SINK_KEYED"].INPUT_SIGNAL), ] - ) + ez.run(components=comps, connections=conns) + + with open(test_fn_raw, "r") as file: + raw_results = [json.loads(_) for _ in file.readlines()] + assert len(raw_results) == num_msgs + assert all( + [ + _[str(ix + 1)] == ("odd" if ix % 2 else "even") + for ix, _ in enumerate(raw_results) + ] + ) - with open(test_fn_keyed, "r") as file: - keyed_results = [json.loads(_) for _ in file.readlines()] - os.remove(test_fn_keyed) - assert len(keyed_results) == num_msgs - assert all([_[str(ix + 1)] == "new key" for ix, _ in enumerate(keyed_results)]) + with open(test_fn_keyed, "r") as file: + keyed_results = [json.loads(_) for _ in file.readlines()] + os.remove(test_fn_keyed) + assert len(keyed_results) == num_msgs + assert all([_[str(ix + 
1)] == "new key" for ix, _ in enumerate(keyed_results)]) def test_filter_key(): num_msgs = 20 - test_fn_raw = get_test_fn() - test_fn_filtered = test_fn_raw.parent / ( - test_fn_raw.stem + "_keyed" + test_fn_raw.suffix - ) + with get_test_fn() as test_fn_raw: + test_fn_filtered = test_fn_raw.parent / ( + test_fn_raw.stem + "_keyed" + test_fn_raw.suffix + ) + + comps = { + "SRC": KeyedAxarrGenerator(num_msgs=num_msgs), + "FILTER": FilterOnKey(key="odd"), + "SINK_RAW": AxarrReceiver(num_msgs=1e9, output_fn=test_fn_raw), + "SINK_FILTERED": AxarrReceiver( + num_msgs=num_msgs // 2, output_fn=test_fn_filtered + ), + } + conns = [ + (comps["SRC"].OUTPUT_SIGNAL, comps["SINK_RAW"].INPUT_SIGNAL), + (comps["SRC"].OUTPUT_SIGNAL, comps["FILTER"].INPUT_SIGNAL), + (comps["FILTER"].OUTPUT_SIGNAL, comps["SINK_FILTERED"].INPUT_SIGNAL), + ] + ez.run(components=comps, connections=conns) + + with open(test_fn_raw, "r") as file: + raw_results = [json.loads(_) for _ in file.readlines()] + assert len(raw_results) == num_msgs - comps = { - "SRC": KeyedAxarrGenerator(num_msgs=num_msgs), - "FILTER": FilterOnKey(key="odd"), - "SINK_RAW": AxarrReceiver(num_msgs=1e9, output_fn=test_fn_raw), - "SINK_FILTERED": AxarrReceiver( - num_msgs=num_msgs // 2, output_fn=test_fn_filtered - ), - } - conns = [ - (comps["SRC"].OUTPUT_SIGNAL, comps["SINK_RAW"].INPUT_SIGNAL), - (comps["SRC"].OUTPUT_SIGNAL, comps["FILTER"].INPUT_SIGNAL), - (comps["FILTER"].OUTPUT_SIGNAL, comps["SINK_FILTERED"].INPUT_SIGNAL), - ] - ez.run(components=comps, connections=conns) - - with open(test_fn_raw, "r") as file: - raw_results = [json.loads(_) for _ in file.readlines()] - os.remove(test_fn_raw) - assert len(raw_results) == num_msgs - - with open(test_fn_filtered, "r") as file: - filtered_results = [json.loads(_) for _ in file.readlines()] - os.remove(test_fn_filtered) - assert len(filtered_results) == num_msgs // 2 - assert all([_[str(ix + 1)] == "odd" for ix, _ in enumerate(filtered_results)]) + with open(test_fn_filtered, "r") as file: + filtered_results = [json.loads(_) for _ in file.readlines()] + os.remove(test_fn_filtered) + assert len(filtered_results) == num_msgs // 2 + assert all([_[str(ix + 1)] == "odd" for ix, _ in enumerate(filtered_results)]) diff --git a/tests/messages/test_modify.py b/tests/messages/test_modify.py index 277db8ad..91f3c708 100644 --- a/tests/messages/test_modify.py +++ b/tests/messages/test_modify.py @@ -1,5 +1,4 @@ import copy -import typing import numpy as np import pytest diff --git a/tests/test_channel.py b/tests/test_channel.py new file mode 100644 index 00000000..e45f7192 --- /dev/null +++ b/tests/test_channel.py @@ -0,0 +1,93 @@ +import asyncio +from uuid import uuid4 + +import pytest + +from ezmsg.core.messagechannel import Channel +from ezmsg.core.messagecache import CacheMiss +from ezmsg.core.netprotocol import Command, uint64_to_bytes +from ezmsg.core.backpressure import Backpressure + + +class DummyWriter: + def __init__(self): + self.buffer: list[bytes] = [] + + def write(self, data: bytes) -> None: + self.buffer.append(data) + + +def _resolved_task(): + loop = asyncio.get_running_loop() + fut = loop.create_future() + fut.set_result(None) + return fut + + +@pytest.mark.asyncio +async def test_channel_acknowledges_remote_messages(): + channel = Channel(uuid4(), uuid4(), 2, None, None, Channel._SENTINEL) + channel._pub_writer = DummyWriter() + channel._pub_task = _resolved_task() + channel._graph_task = _resolved_task() + + client_id = uuid4() + queue: asyncio.Queue = asyncio.Queue() + 
channel.register_client(client_id, queue) + + msg_id = 5 + payload = {"value": 42} + channel.cache.put_local(payload, msg_id) + channel._notify_clients(msg_id) + + assert queue.qsize() == 1 + queued_pub, queued_msg = queue.get_nowait() + assert queued_pub == channel.pub_id + assert queued_msg == msg_id + + with channel.get(msg_id, client_id) as obj: + assert obj == payload + + with pytest.raises(CacheMiss): + _ = channel.cache[msg_id] + + buf_idx = msg_id % channel.num_buffers + assert channel.backpressure.buffers[buf_idx].is_empty + + expected_ack = Command.RX_ACK.value + uint64_to_bytes(msg_id) + assert channel._pub_writer.buffer[-1] == expected_ack + + +@pytest.mark.asyncio +async def test_channel_releases_local_backpressure(monkeypatch): + channel = Channel(uuid4(), uuid4(), 2, None, None, Channel._SENTINEL) + channel._pub_writer = DummyWriter() + channel._pub_task = _resolved_task() + channel._graph_task = _resolved_task() + + local_bp = Backpressure(channel.num_buffers) + channel.register_client(channel.pub_id, None, local_bp) + + client_id = uuid4() + queue: asyncio.Queue = asyncio.Queue() + channel.register_client(client_id, queue) + + msg_id = 3 + payload = "local" + channel.put_local(msg_id, payload) + + assert queue.qsize() == 1 + queue.get_nowait() + + with channel.get(msg_id, client_id) as obj: + assert obj == payload + + buf_idx = msg_id % channel.num_buffers + assert local_bp.buffers[buf_idx].is_empty + assert channel._pub_writer.buffer == [] + + +def test_channel_put_local_requires_local_backpressure(): + channel = Channel(uuid4(), uuid4(), 1, None, None, Channel._SENTINEL) + with pytest.raises(ValueError): + channel.put_local(1, "no pub") diff --git a/tests/test_channelmanager.py b/tests/test_channelmanager.py new file mode 100644 index 00000000..0f78a6f8 --- /dev/null +++ b/tests/test_channelmanager.py @@ -0,0 +1,105 @@ +import asyncio +from dataclasses import dataclass +from uuid import uuid4 + +import pytest + +from ezmsg.core.channelmanager import ChannelManager +from ezmsg.core.backpressure import Backpressure +from ezmsg.core.netprotocol import Address +from ezmsg.core import channelmanager as channelmanager_module + + +@dataclass +class DummyChannel: + clients: dict + + def __init__(self): + self.clients = {} + self.closed = False + self.waited = False + self.local_bp: dict = {} + + def register_client(self, client_id, queue, local_backpressure): + self.clients[client_id] = queue + if local_backpressure is not None: + self.local_bp[client_id] = local_backpressure + + def unregister_client(self, client_id): + del self.clients[client_id] + + def close(self): + self.closed = True + + async def wait_closed(self): + self.waited = True + + +@pytest.mark.asyncio +async def test_channel_manager_reuses_existing_channel(monkeypatch): + dummy_channel = DummyChannel() + + async def fake_create(pub_id, address): + return dummy_channel + + monkeypatch.setattr(channelmanager_module.Channel, "create", fake_create) + + manager = ChannelManager() + pub_id = uuid4() + client_one = uuid4() + client_two = uuid4() + + queue_one: asyncio.Queue = asyncio.Queue() + queue_two: asyncio.Queue = asyncio.Queue() + + channel_a = await manager.register(pub_id, client_one, queue_one) + channel_b = await manager.register(pub_id, client_two, queue_two) + + assert channel_a is dummy_channel + assert channel_b is dummy_channel + assert dummy_channel.clients[client_one] is queue_one + assert dummy_channel.clients[client_two] is queue_two + + +@pytest.mark.asyncio +async def 
diff --git a/tests/test_channelmanager.py b/tests/test_channelmanager.py
new file mode 100644
index 00000000..0f78a6f8
--- /dev/null
+++ b/tests/test_channelmanager.py
@@ -0,0 +1,105 @@
+import asyncio
+from dataclasses import dataclass
+from uuid import uuid4
+
+import pytest
+
+from ezmsg.core.channelmanager import ChannelManager
+from ezmsg.core.backpressure import Backpressure
+from ezmsg.core.netprotocol import Address
+from ezmsg.core import channelmanager as channelmanager_module
+
+
+@dataclass
+class DummyChannel:
+    clients: dict
+
+    def __init__(self):
+        self.clients = {}
+        self.closed = False
+        self.waited = False
+        self.local_bp: dict = {}
+
+    def register_client(self, client_id, queue, local_backpressure):
+        self.clients[client_id] = queue
+        if local_backpressure is not None:
+            self.local_bp[client_id] = local_backpressure
+
+    def unregister_client(self, client_id):
+        del self.clients[client_id]
+
+    def close(self):
+        self.closed = True
+
+    async def wait_closed(self):
+        self.waited = True
+
+
+@pytest.mark.asyncio
+async def test_channel_manager_reuses_existing_channel(monkeypatch):
+    dummy_channel = DummyChannel()
+
+    async def fake_create(pub_id, address):
+        return dummy_channel
+
+    monkeypatch.setattr(channelmanager_module.Channel, "create", fake_create)
+
+    manager = ChannelManager()
+    pub_id = uuid4()
+    client_one = uuid4()
+    client_two = uuid4()
+
+    queue_one: asyncio.Queue = asyncio.Queue()
+    queue_two: asyncio.Queue = asyncio.Queue()
+
+    channel_a = await manager.register(pub_id, client_one, queue_one)
+    channel_b = await manager.register(pub_id, client_two, queue_two)
+
+    assert channel_a is dummy_channel
+    assert channel_b is dummy_channel
+    assert dummy_channel.clients[client_one] is queue_one
+    assert dummy_channel.clients[client_two] is queue_two
+
+
+@pytest.mark.asyncio
+async def test_channel_manager_unregister_closes_channel(monkeypatch):
+    dummy_channel = DummyChannel()
+
+    async def fake_create(pub_id, address):
+        dummy_channel.address = address
+        return dummy_channel
+
+    monkeypatch.setattr(channelmanager_module.Channel, "create", fake_create)
+
+    manager = ChannelManager()
+    pub_id = uuid4()
+    client_id = uuid4()
+
+    queue: asyncio.Queue = asyncio.Queue()
+    await manager.register(pub_id, client_id, queue)
+    await manager.unregister(pub_id, client_id)
+
+    default_address = Address.from_string(channelmanager_module.GRAPHSERVER_ADDR)
+    assert pub_id not in manager._registry[default_address]
+    assert dummy_channel.closed
+    assert dummy_channel.waited
+
+
+@pytest.mark.asyncio
+async def test_channel_manager_registers_local_publisher(monkeypatch):
+    dummy_channel = DummyChannel()
+
+    async def fake_create(pub_id, address):
+        return dummy_channel
+
+    monkeypatch.setattr(channelmanager_module.Channel, "create", fake_create)
+
+    manager = ChannelManager()
+    pub_id = uuid4()
+    local_bp = Backpressure(1)
+
+    channel = await manager.register_local_pub(pub_id, local_bp)
+
+    assert channel is dummy_channel
+    assert dummy_channel.clients[pub_id] is None
+    assert dummy_channel.local_bp[pub_id] is local_bp
diff --git a/tests/test_generator.py b/tests/test_generator.py
index 37921cc9..df013006 100644
--- a/tests/test_generator.py
+++ b/tests/test_generator.py
@@ -163,25 +163,25 @@ def test_gen_to_unit_any():
     assert MyUnit.OUTPUT.msg_type == list[typing.Any]

     num_msgs = 5
-    test_filename = get_test_fn()
-    comps = {
-        "SIMPLE_PUB": MessageGenerator(num_msgs=num_msgs),
-        "MYUNIT": MyUnit(),
-        "SIMPLE_SUB": MessageAnyReceiver(num_msgs=num_msgs, output_fn=test_filename),
-    }
-    ez.run(
-        components=comps,
-        connections=(
-            (comps["SIMPLE_PUB"].OUTPUT, comps["MYUNIT"].INPUT),
-            (comps["MYUNIT"].OUTPUT, comps["SIMPLE_SUB"].INPUT),
-        ),
-    )
-    results = []
-    with open(test_filename, "r") as file:
-        lines = file.readlines()
-        for line in lines:
-            results.append(json.loads(line))
-    os.remove(test_filename)
+    with get_test_fn() as test_filename:
+        comps = {
+            "SIMPLE_PUB": MessageGenerator(num_msgs=num_msgs),
+            "MYUNIT": MyUnit(),
+            "SIMPLE_SUB": MessageAnyReceiver(num_msgs=num_msgs, output_fn=test_filename),
+        }
+        ez.run(
+            components=comps,
+            connections=(
+                (comps["SIMPLE_PUB"].OUTPUT, comps["MYUNIT"].INPUT),
+                (comps["MYUNIT"].OUTPUT, comps["SIMPLE_SUB"].INPUT),
+            ),
+        )
+        results = []
+        with open(test_filename, "r") as file:
+            lines = file.readlines()
+            for line in lines:
+                results.append(json.loads(line))
+
     # We don't really care about the contents; functionality was confirmed in a separate test.
     # Keep this simple.
     assert len(results) == num_msgs
@@ -224,25 +224,26 @@ def test_gen_to_unit_axarr():
     assert MyUnit.OUTPUT_SIGNAL.msg_type is AxisArray

     num_msgs = 5
-    test_filename = get_test_fn()
-    comps = {
-        "SIMPLE_PUB": AxarrGenerator(num_msgs=num_msgs),
-        "MYUNIT": MyUnit(),
-        "SIMPLE_SUB": AxarrReceiver(num_msgs=num_msgs, output_fn=test_filename),
-    }
-    ez.run(
-        components=comps,
-        connections=(
-            (comps["SIMPLE_PUB"].OUTPUT_SIGNAL, comps["MYUNIT"].INPUT_SIGNAL),
-            (comps["MYUNIT"].OUTPUT_SIGNAL, comps["SIMPLE_SUB"].INPUT_SIGNAL),
-        ),
-    )
-    results = []
-    with open(test_filename, "r") as file:
-        lines = file.readlines()
-        for line in lines:
-            results.append(json.loads(line))
-    os.remove(test_filename)
+
+    with get_test_fn() as test_filename:
+        comps = {
+            "SIMPLE_PUB": AxarrGenerator(num_msgs=num_msgs),
+            "MYUNIT": MyUnit(),
+            "SIMPLE_SUB": AxarrReceiver(num_msgs=num_msgs, output_fn=test_filename),
+        }
+        ez.run(
+            components=comps,
+            connections=(
+                (comps["SIMPLE_PUB"].OUTPUT_SIGNAL, comps["MYUNIT"].INPUT_SIGNAL),
+                (comps["MYUNIT"].OUTPUT_SIGNAL, comps["SIMPLE_SUB"].INPUT_SIGNAL),
+            ),
+        )
+        results = []
+        with open(test_filename, "r") as file:
+            lines = file.readlines()
+            for line in lines:
+                results.append(json.loads(line))
+
     assert np.array_equal(
         results[-1][f"{num_msgs}"], np.hstack([np.arange(_) for _ in range(num_msgs)])
     )
diff --git a/tests/test_graph.py b/tests/test_graph.py
index b7b414ac..5da45fb1 100644
--- a/tests/test_graph.py
+++ b/tests/test_graph.py
@@ -19,6 +19,28 @@
 simple_graph_2 = [("w", "x"), ("w", "y"), ("x", "z"), ("y", "z")]


+@pytest.mark.asyncio
+async def test_pub_first():
+    async with GraphContext() as context:
+        await context.publisher("a")
+        await asyncio.sleep(0.1)
+        await context.connect("a", "b")
+        await asyncio.sleep(0.1)
+        await context.subscriber("b")
+        await asyncio.sleep(0.1)
+
+
+@pytest.mark.asyncio
+async def test_sub_first():
+    async with GraphContext() as context:
+        await context.subscriber("b")
+        await asyncio.sleep(0.1)
+        await context.connect("a", "b")
+        await asyncio.sleep(0.1)
+        await context.publisher("a")
+        await asyncio.sleep(0.1)
+
+
 @pytest.mark.asyncio
 async def test_graph():
     async with GraphContext() as context:
@@ -28,9 +50,13 @@ async def test_graph():
         for edge in simple_graph_1:
             await context.connect(*edge)

+        await asyncio.sleep(0.1)
+
         await context.subscriber("c")
         await context.publisher("b")

+        await asyncio.sleep(0.1)
+
         for edge in simple_graph_2:
             await context.connect(*edge)

@@ -53,6 +79,18 @@ async def test_graph():
         await context.connect("c", "a")


+@pytest.mark.asyncio
+async def test_comms_simple():
+    async with GraphContext() as context:
+        b_sub = await context.subscriber("b")
+        await context.connect("a", "b")
+        a_pub = await context.publisher("a")
+        await a_pub.broadcast("HELLO")
+        print("DONE BROADCASTING")
+        msg = await b_sub.recv()
+        assert msg == "HELLO"
+
+
 @pytest.mark.asyncio
 async def test_comms():
     async with GraphContext() as context:
@@ -192,8 +230,8 @@ async def run_async(self) -> None:
             if start is not None:
                 delta = time.perf_counter() - start
                 message_rate = int(self.n_msgs / delta)
-                print(f"{ message_rate } msgs/sec")
-                print(f"{ msg_size * message_rate / 1024 / 1024 / 1024 } GB/sec")
+                print(f"{message_rate} msgs/sec")
+                print(f"{msg_size * message_rate / 1024 / 1024 / 1024} GB/sec")

             await asyncio.get_running_loop().run_in_executor(
                 None, self.stop_barrier.wait
diff --git a/tests/test_perf_configs.py b/tests/test_perf_configs.py
new file mode 100644
index 00000000..fc7178aa
--- /dev/null
+++ b/tests/test_perf_configs.py
@@ -0,0 +1,99 @@
+import contextlib
+import os
+import tempfile
+from pathlib import Path
+
+import pytest
+
+from ezmsg.core.graphserver import GraphServer
+from ezmsg.util.perf.impl import Communication, CONFIGS, perform_test
+
+
+PERF_MAX_DURATION = 0.5
+PERF_NUM_MSGS = 8
+PERF_MSG_SIZES = [64, 2**20]
+PERF_NUM_BUFFERS = 2
+CLIENTS_PER_CONFIG = {
+    "fanout": 2,
+    "fanin": 2,
+    "relay": 2,
+}
+
+
+def _run_perf_case(
+    config_name: str,
+    comm: Communication,
+    msg_size: int,
+    server: GraphServer,
+) -> None:
+    metrics = perform_test(
+        n_clients=CLIENTS_PER_CONFIG[config_name],
+        max_duration=PERF_MAX_DURATION,
+        num_msgs=PERF_NUM_MSGS,
+        msg_size=msg_size,
+        buffers=PERF_NUM_BUFFERS,
+        comms=comm,
+        config=CONFIGS[config_name],
+        graph_address=server.address,
+    )
+    assert metrics.num_msgs > 0, (
+        f"Failed to exchange messages for {config_name}/{comm.value}/msg={msg_size}"
+    )
+
+
+@contextlib.contextmanager
+def _file_lock(path: Path):
+    path.parent.mkdir(parents=True, exist_ok=True)
+    with open(path, "w+") as lock_file:
+        lock_file.write("0")
+        lock_file.flush()
+        if os.name == "nt":
+            import msvcrt
+
+            msvcrt.locking(lock_file.fileno(), msvcrt.LK_LOCK, 1)
+        else:
+            import fcntl
+
+            fcntl.flock(lock_file, fcntl.LOCK_EX)
+        try:
+            yield
+        finally:
+            if os.name == "nt":
+                msvcrt.locking(lock_file.fileno(), msvcrt.LK_UNLCK, 1)
+            else:
+                fcntl.flock(lock_file, fcntl.LOCK_UN)
+
+
+@pytest.fixture(scope="module")
+def perf_graph_server():
+    """
+    Spin up a dedicated graph server on an ephemeral port for the perf smoke tests.
+    The shared instance keeps startup/shutdown overhead down while isolating it
+    from the canonical graph server that other tests rely on.
+    """
+    lock_path = Path(tempfile.gettempdir()) / "ezmsg_perf_smoke.lock"
+    with _file_lock(lock_path):
+        server = GraphServer()
+        server.start()
+        try:
+            yield server
+        finally:
+            server.stop()
+
+
+@pytest.mark.parametrize("msg_size", PERF_MSG_SIZES, ids=lambda s: f"msg={s}")
+@pytest.mark.parametrize("comm", list(Communication), ids=lambda c: f"comm={c.value}")
+def test_fanout_perf(perf_graph_server, comm, msg_size):
+    _run_perf_case("fanout", comm, msg_size, perf_graph_server)
+
+
+@pytest.mark.parametrize("msg_size", PERF_MSG_SIZES, ids=lambda s: f"msg={s}")
+@pytest.mark.parametrize("comm", list(Communication), ids=lambda c: f"comm={c.value}")
+def test_fanin_perf(perf_graph_server, comm, msg_size):
+    _run_perf_case("fanin", comm, msg_size, perf_graph_server)
+
+
+@pytest.mark.parametrize("msg_size", PERF_MSG_SIZES, ids=lambda s: f"msg={s}")
+@pytest.mark.parametrize("comm", list(Communication), ids=lambda c: f"comm={c.value}")
+def test_relay_perf(perf_graph_server, comm, msg_size):
+    _run_perf_case("relay", comm, msg_size, perf_graph_server)
diff --git a/tests/test_run.py b/tests/test_run.py
index bc5046c2..323bc4ba 100644
--- a/tests/test_run.py
+++ b/tests/test_run.py
@@ -63,70 +63,69 @@ def func(self):

 @pytest.mark.parametrize("num_messages", [1, 5, 10])
 def test_local_system(toy_system_fixture, num_messages):
-    test_filename = get_test_fn()
-    system = toy_system_fixture(
-        ToySystemSettings(num_msgs=num_messages, output_fn=test_filename)
-    )
-    ez.run(SYSTEM=system, profiler_log_name="test_profiler.log")
-    assert os.environ.get("EZMSG_PROFILER") == "test_profiler.log"
-
-    results = []
-    with open(test_filename, "r") as file:
-        lines = file.readlines()
-        for line in lines:
-            results.append(json.loads(line))
-    os.remove(test_filename)
-    assert len(results) == num_messages
+    with get_test_fn() as test_filename:
+        system = toy_system_fixture(
+            ToySystemSettings(num_msgs=num_messages, output_fn=test_filename)
+        )
+        ez.run(SYSTEM=system, profiler_log_name="test_profiler.log")
+        assert os.environ.get("EZMSG_PROFILER") == "test_profiler.log"
+
+        results = []
+        with open(test_filename, "r") as file:
+            lines = file.readlines()
+            for line in lines:
+                results.append(json.loads(line))
+
+        assert len(results) == num_messages


 @pytest.mark.parametrize("passthrough_settings", [False, True])
 @pytest.mark.parametrize("num_messages", [1, 5, 10])
 def test_run_comps_conns(passthrough_settings, num_messages):
-    test_filename = get_test_fn()
-    if passthrough_settings:
-        comps = {
-            "SIMPLE_PUB": MessageGenerator(num_msgs=num_messages),
-            "SIMPLE_SUB": MessageReceiver(
-                num_msgs=num_messages, output_fn=test_filename
-            ),
-        }
-    else:
-        comps = {
-            "SIMPLE_PUB": MessageGenerator(
-                MessageGeneratorSettings(num_msgs=num_messages)
-            ),
-            "SIMPLE_SUB": MessageReceiver(
-                MessageReceiverSettings(num_msgs=num_messages, output_fn=test_filename)
-            ),
-        }
-    conns = ((comps["SIMPLE_PUB"].OUTPUT, comps["SIMPLE_SUB"].INPUT),)
-
-    ez.run(components=comps, connections=conns)
-
-    results = []
-    with open(test_filename, "r") as file:
-        lines = file.readlines()
-        for line in lines:
-            results.append(json.loads(line))
-    os.remove(test_filename)
-    assert len(results) == num_messages
+    with get_test_fn() as test_filename:
+        if passthrough_settings:
+            comps = {
+                "SIMPLE_PUB": MessageGenerator(num_msgs=num_messages),
+                "SIMPLE_SUB": MessageReceiver(
+                    num_msgs=num_messages, output_fn=test_filename
+                ),
+            }
+        else:
+            comps = {
+                "SIMPLE_PUB": MessageGenerator(
+                    MessageGeneratorSettings(num_msgs=num_messages)
+                ),
+                "SIMPLE_SUB": MessageReceiver(
+                    MessageReceiverSettings(num_msgs=num_messages, output_fn=test_filename)
+                ),
+            }
+        conns = ((comps["SIMPLE_PUB"].OUTPUT, comps["SIMPLE_SUB"].INPUT),)
+
+        ez.run(components=comps, connections=conns)
+
+        results = []
+        with open(test_filename, "r") as file:
+            lines = file.readlines()
+            for line in lines:
+                results.append(json.loads(line))
+
+        assert len(results) == num_messages


 @pytest.mark.parametrize("passthrough_settings", [False, True])
 @pytest.mark.parametrize("num_messages", [1, 5, 10])
 def test_run_collection(passthrough_settings, num_messages):
-    test_filename = get_test_fn()
-    if passthrough_settings:
-        collection = ToySystem(num_msgs=num_messages, output_fn=test_filename)
-    else:
-        collection = ToySystem(
-            ToySystemSettings(num_msgs=num_messages, output_fn=test_filename)
-        )
-    ez.run(collection)
-    results = []
-    with open(test_filename, "r") as file:
-        lines = file.readlines()
-        for line in lines:
-            results.append(json.loads(line))
-    os.remove(test_filename)
-    assert len(results) == num_messages
+    with get_test_fn() as test_filename:
+        if passthrough_settings:
+            collection = ToySystem(num_msgs=num_messages, output_fn=test_filename)
+        else:
+            collection = ToySystem(
+                ToySystemSettings(num_msgs=num_messages, output_fn=test_filename)
+            )
+        ez.run(collection)
+        results = []
+        with open(test_filename, "r") as file:
+            lines = file.readlines()
+            for line in lines:
+                results.append(json.loads(line))
+        assert len(results) == num_messages
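Several of the test hunks above replace `test_filename = get_test_fn()` plus a trailing `os.remove(test_filename)` with `with get_test_fn() as test_filename:`, which implies the shared test helper now acts as a context manager that yields the temporary output path and deletes it on exit, even when an assertion fails mid-test. The project's actual `get_test_fn` lives in the test utilities and is not part of this patch; the snippet below is only a hedged sketch of what such a helper could look like, and the `suffix` parameter and `Path` return type are assumptions made for illustration.

# Hypothetical sketch only: not the implementation shipped in this patch.
import contextlib
import os
import tempfile
from pathlib import Path
from typing import Iterator


@contextlib.contextmanager
def get_test_fn(suffix: str = ".txt") -> Iterator[Path]:
    """Yield a unique temporary file path and remove the file on exit."""
    fd, name = tempfile.mkstemp(suffix=suffix)
    os.close(fd)  # the tests open and write the file themselves
    path = Path(name)
    try:
        yield path
    finally:
        # Mirrors the explicit os.remove(test_filename) calls that the hunks above delete.
        path.unlink(missing_ok=True)

Yielding the path from a context manager keeps the call sites nearly unchanged while guaranteeing cleanup, which is why the converted tests can drop their explicit os.remove() calls.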