From 11d48a462afbc2c9144ff770c485afa969fb4c9c Mon Sep 17 00:00:00 2001 From: Preston Peranich Date: Tue, 13 Jan 2026 21:03:13 -0500 Subject: [PATCH] feat: Add GraphRunner API to replace legacy run method. --- src/ezmsg/core/__init__.py | 4 +- src/ezmsg/core/backend.py | 389 +++++++++++++++++++++++++++---------- 2 files changed, 289 insertions(+), 104 deletions(-) diff --git a/src/ezmsg/core/__init__.py b/src/ezmsg/core/__init__.py index fc36153..f377cb3 100644 --- a/src/ezmsg/core/__init__.py +++ b/src/ezmsg/core/__init__.py @@ -19,6 +19,8 @@ "Unit", "State", "run", + "GraphRunner", + "GraphRunnerStartError", "Complete", "NormalTermination", "GraphServer", @@ -39,7 +41,7 @@ from .collection import Collection, NetworkDefinition from .unit import Unit, task, publisher, subscriber, main, timeit, process, thread from .stream import InputStream, OutputStream -from .backend import run +from .backend import run, GraphRunner, GraphRunnerStartError from .backendprocess import Complete, NormalTermination from .graphserver import GraphServer from .graphcontext import GraphContext diff --git a/src/ezmsg/core/backend.py b/src/ezmsg/core/backend.py index 09b35bd..49d287d 100644 --- a/src/ezmsg/core/backend.py +++ b/src/ezmsg/core/backend.py @@ -4,12 +4,10 @@ import enum import logging import os - +from threading import BrokenBarrierError from multiprocessing import Event, Barrier from multiprocessing.synchronize import Event as EventType from multiprocessing.synchronize import Barrier as BarrierType -from multiprocessing.connection import wait, Connection -from socket import socket from .netprotocol import DEFAULT_SHM_SIZE, AddressType @@ -43,13 +41,16 @@ def __init__( self, process_units: list[list[Unit]], connections: list[tuple[str, str]] = [], + start_participant: bool = False, ) -> None: self.connections = connections self._process_units = process_units self._processes = None self.term_ev = Event() - self.start_barrier = Barrier(len(process_units)) + self.start_barrier = Barrier( + len(process_units) + (1 if start_participant else 0) + ) self.stop_barrier = Barrier(len(process_units)) def create_processes( @@ -75,6 +76,10 @@ def processes(self) -> list[BackendProcess]: else: return self._processes + @property + def process_count(self) -> int: + return len(self._process_units) + @classmethod def setup( cls, @@ -83,6 +88,7 @@ def setup( connections: NetworkDefinition | None = None, process_components: AbstractCollection[Component] | None = None, force_single_process: bool = False, + start_participant: bool = False, ) -> "ExecutionContext | None": graph_connections: list[tuple[str, str]] = [] @@ -148,9 +154,273 @@ def configure_collections(comp: Component): return cls( processes, graph_connections, + start_participant, ) +class GraphRunnerStartError(RuntimeError): + pass + + +class GraphRunner: + _components: Mapping[str, Component] + _execution_context: ExecutionContext | None + _graph_context: GraphContext | None + _loop: asyncio.AbstractEventLoop | None + _loop_cm: object | None + _main_process: BackendProcess | None + _spawned_processes: list[BackendProcess] + _start_participant: bool + _cleanup_done: bool + _graph_server_spawned: bool + _started: bool + _stopped: bool + + def __init__( + self, + components: Mapping[str, Component] | None = None, + root_name: str | None = None, + connections: NetworkDefinition | None = None, + process_components: AbstractCollection[Component] | None = None, + backend_process: type[BackendProcess] = DefaultBackendProcess, + graph_address: AddressType | None = None, 
+ force_single_process: bool = False, + profiler_log_name: str | None = None, + **components_kwargs: Component, + ) -> None: + if components is not None and isinstance(components, Component): + components = {"SYSTEM": components} + logger.warning( + "Passing a single Component without naming the Component is now Deprecated." + ) + components = either_dict_or_kwargs(components, components_kwargs, "run") + if components is None: + raise ValueError("Must supply at least one component to run") + + self._components = components + self._root_name = root_name + self._connections = connections + self._process_components = process_components + self._backend_process = backend_process + self._graph_address = graph_address + self._force_single_process = force_single_process + self._profiler_log_name = profiler_log_name + + self._execution_context = None + self._graph_context = None + self._loop = None + self._loop_cm = None + self._main_process = None + self._spawned_processes = [] + self._start_participant = False + self._cleanup_done = False + self._graph_server_spawned = False + self._started = False + self._stopped = False + + @property + def graph_address(self) -> AddressType | None: + if self._graph_context is not None: + return self._graph_context.graph_address + return self._graph_address + + @property + def graph_server_spawned(self) -> bool: + return self._graph_server_spawned + + @property + def connections(self) -> list[tuple[str, str]]: + if self._execution_context is None: + return [] + return list(self._execution_context.connections) + + @property + def processes(self) -> list[BackendProcess]: + if self._execution_context is None: + raise ValueError("GraphRunner has not initialized processes") + return self._execution_context.processes + + @property + def running(self) -> bool: + return self._started + + def start(self) -> None: + if self._started: + raise RuntimeError("GraphRunner is already running") + if self._stopped: + raise RuntimeError("GraphRunner cannot be restarted") + if self._force_single_process: + raise ValueError("force_single_process is only supported with run_blocking") + if not self._initialize(force_single_process=False, wait_for_ready=True): + return + + self._start_processes(self.processes) + + if self._start_participant and self._execution_context is not None: + try: + self._execution_context.start_barrier.wait() + except BrokenBarrierError as err: + self._execution_context.term_ev.set() + self._join_spawned_processes() + self._cleanup() + self._stopped = True + raise GraphRunnerStartError( + "GraphRunner failed to start. One or more processes exited before " + "reaching the start barrier; check logs for earlier exceptions." 
+ ) from err + self._started = True + if self._stopped: + self._started = False + + def stop(self) -> None: + if not self._started: + raise RuntimeError("GraphRunner is not running") + if self._execution_context is None: + raise RuntimeError("GraphRunner execution context is invalid!") + self._execution_context.term_ev.set() + self._join_spawned_processes() + self._cleanup() + self._started = False + self._stopped = True + + def run_blocking(self) -> None: + if self._started: + raise RuntimeError("GraphRunner is already running") + if self._stopped: + raise RuntimeError("GraphRunner cannot be restarted") + if not self._initialize( + force_single_process=self._force_single_process, wait_for_ready=False + ): + return + self._started = True + self._run_main_process() + + def _initialize(self, force_single_process: bool, wait_for_ready: bool) -> bool: + os.environ["EZMSG_PROFILER"] = self._profiler_log_name or "ezprofiler.log" + self._cleanup_done = False + self._spawned_processes = [] + self._start_participant = wait_for_ready + + self._execution_context = ExecutionContext.setup( + self._components, + self._root_name, + self._connections, + self._process_components, + force_single_process, + wait_for_ready, + ) + + if self._execution_context is None: + return False + + self._loop_cm = new_threaded_event_loop() + self._loop = self._loop_cm.__enter__() + + try: + + async def create_graph_context() -> GraphContext: + return await GraphContext(self._graph_address).__aenter__() + + graph_context = asyncio.run_coroutine_threadsafe( + create_graph_context(), self._loop + ).result() + self._graph_context = graph_context + self._graph_server_spawned = graph_context._graph_server is not None + + if graph_context._graph_server is None: + address = graph_context.graph_address + if address is None: + address = GraphService.default_address() + logger.info(f"Connected to GraphServer @ {address}") + else: + logger.info(f"Spawned GraphServer @ {graph_context.graph_address}") + + self._execution_context.create_processes( + graph_address=graph_context.graph_address, + backend_process=self._backend_process, + ) + + async def setup_graph() -> None: + for edge in self._execution_context.connections: + await graph_context.connect(*edge) + + asyncio.run_coroutine_threadsafe(setup_graph(), self._loop).result() + + if len(self._execution_context.processes) > 1: + logger.info( + f"Running in {len(self._execution_context.processes)} processes." + ) + + except Exception: + self._cleanup() + raise + + return True + + def _start_processes(self, processes: list[BackendProcess]) -> None: + for proc in processes: + proc.start() + self._spawned_processes.append(proc) + + def _join_spawned_processes(self) -> None: + for proc in self._spawned_processes: + proc.join() + + def _run_main_process(self) -> None: + if self._execution_context is None or self._loop is None: + return + self._main_process = self._execution_context.processes[0] + self._start_processes(self._execution_context.processes[1:]) + + try: + self._main_process.process(self._loop) + self._join_spawned_processes() + logger.info("All processes exited normally") + + except KeyboardInterrupt: + logger.info( + "Attempting graceful shutdown, interrupt again to force quit..." 
+ ) + self._execution_context.term_ev.set() + + try: + self._join_spawned_processes() + + except KeyboardInterrupt: + logger.warning("Interrupt intercepted, force quitting") + self._execution_context.start_barrier.abort() + self._execution_context.stop_barrier.abort() + for proc in self._spawned_processes: + proc.terminate() + + finally: + self._join_spawned_processes() + self._cleanup() + self._started = False + self._stopped = True + + def _cleanup(self) -> None: + if self._cleanup_done: + return + self._cleanup_done = True + + if self._graph_context is not None and self._loop is not None: + + async def cleanup_graph() -> None: + await self._graph_context.__aexit__(None, None, None) + + asyncio.run_coroutine_threadsafe(cleanup_graph(), self._loop).result() + + if self._loop_cm is not None: + self._loop_cm.__exit__(None, None, None) + + self._loop_cm = None + self._loop = None + self._graph_context = None + self._spawned_processes = [] + self._start_participant = False + + def run_system( system: Collection, num_buffers: int = 32, @@ -223,105 +493,18 @@ def run( .. note:: The old method :obj:`run_system` has been deprecated and uses ``run()`` instead. """ - os.environ["EZMSG_PROFILER"] = profiler_log_name or "ezprofiler.log" - # FIXME: This function is the last major re-implementation needed to make this - # codebase more maintainable. - - if components is not None and isinstance(components, Component): - components = {"SYSTEM": components} - logger.warning( - "Passing a single Component without naming the Component is now Deprecated." - ) - components = either_dict_or_kwargs(components, components_kwargs, "run") - if components is None: - raise ValueError("Must supply at least one component to run") - - with new_threaded_event_loop() as loop: - execution_context = ExecutionContext.setup( - components, - root_name, - connections, - process_components, - force_single_process, - ) - - if execution_context is None: - return - - # FIXME: When done this way, we don't exit the graph_context on exception - async def create_graph_context() -> GraphContext: - return await GraphContext(graph_address).__aenter__() - - # FIXME: This sort of stuff should all be done in a separate async function... - # Done this way, its ugly as hell and opens us up to a lot of issues with - # entering and exiting context properly on exceptions. 
- graph_context = asyncio.run_coroutine_threadsafe( - create_graph_context(), loop - ).result() - - if graph_context._graph_server is None: - address = graph_context.graph_address - if address is None: - address = GraphService.default_address() - logger.info(f"Connected to GraphServer @ {address}") - else: - logger.info(f"Spawned GraphServer @ {graph_context.graph_address}") - - execution_context.create_processes( - graph_address=graph_context.graph_address, backend_process=backend_process - ) - - async def cleanup_graph() -> None: - await graph_context.__aexit__(None, None, None) - - async def setup_graph() -> None: - for edge in execution_context.connections: - await graph_context.connect(*edge) - - asyncio.run_coroutine_threadsafe(setup_graph(), loop).result() - - if len(execution_context.processes) > 1: - logger.info(f"Running in {len(execution_context.processes)} processes.") - - main_process = execution_context.processes[0] - other_processes = execution_context.processes[1:] - - sentinels: set[Connection | socket | int] = set() - - for proc in other_processes: - proc.start() - sentinels.add(proc.sentinel) - - def join_all_other_processes(): - while len(sentinels): - done = wait(sentinels, timeout=0.1) - for sentinel in done: - sentinels.discard(sentinel) - - try: - main_process.process(loop) - join_all_other_processes() - logger.info("All processes exited normally") - - except KeyboardInterrupt: - logger.info( - "Attempting graceful shutdown, interrupt again to force quit..." - ) - execution_context.term_ev.set() - - try: - join_all_other_processes() - - except KeyboardInterrupt: - logger.warning("Interrupt intercepted, force quitting") - execution_context.start_barrier.abort() - execution_context.stop_barrier.abort() - for proc in other_processes: - proc.terminate() - - finally: - join_all_other_processes() - asyncio.run_coroutine_threadsafe(cleanup_graph(), loop).result() + runner = GraphRunner( + components=components, + root_name=root_name, + connections=connections, + process_components=process_components, + backend_process=backend_process, + graph_address=graph_address, + force_single_process=force_single_process, + profiler_log_name=profiler_log_name, + **components_kwargs, + ) + runner.run_blocking() def collect_processes(
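
A minimal usage sketch of the GraphRunner lifecycle introduced by this patch. MySystem and MY_CONNECTIONS below are hypothetical placeholders for whatever ezmsg Collection/Unit and NetworkDefinition the caller already has; the rest (GraphRunner, GraphRunnerStartError, start, stop, run_blocking) comes from the patch itself.

    import ezmsg.core as ez

    # Blocking use, equivalent to the legacy ez.run(...) entry point,
    # which now delegates to GraphRunner.run_blocking():
    runner = ez.GraphRunner(SYSTEM=MySystem(), connections=MY_CONNECTIONS)
    runner.run_blocking()

    # Non-blocking use: start() participates in the multiprocessing start
    # barrier and returns once every backend process has checked in; stop()
    # sets the termination event, joins the spawned processes, and tears
    # down the graph context.
    runner = ez.GraphRunner(SYSTEM=MySystem(), connections=MY_CONNECTIONS)
    try:
        runner.start()
    except ez.GraphRunnerStartError:
        # A process exited before reaching the start barrier; the underlying
        # exception should appear earlier in the logs.
        raise
    ...  # interact with the running graph from the calling process
    runner.stop()

Note that a GraphRunner instance is single-use: once stop() has completed (or run_blocking() has returned), both start() and run_blocking() raise rather than restarting the graph, and force_single_process is only honored by run_blocking().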