From 761543ff34995d88b8f3400fde1a7d7751288ce8 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Fri, 19 Sep 2025 16:30:00 -0400 Subject: [PATCH 01/88] much more comprehensive perf test with analysis: ezmsg-perf --- pyproject.toml | 4 +- src/ezmsg/util/perf/__init__.py | 0 src/ezmsg/util/perf/analysis.py | 113 ++++++++++ src/ezmsg/util/perf/command.py | 55 +++++ src/ezmsg/util/perf/eval.py | 91 ++++++++ src/ezmsg/util/perf/util.py | 356 ++++++++++++++++++++++++++++++++ src/ezmsg/util/perf_test.py | 220 -------------------- 7 files changed, 618 insertions(+), 221 deletions(-) create mode 100644 src/ezmsg/util/perf/__init__.py create mode 100644 src/ezmsg/util/perf/analysis.py create mode 100644 src/ezmsg/util/perf/command.py create mode 100644 src/ezmsg/util/perf/eval.py create mode 100644 src/ezmsg/util/perf/util.py delete mode 100644 src/ezmsg/util/perf_test.py diff --git a/pyproject.toml b/pyproject.toml index df1f469a..961271b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,8 @@ test = [ "pytest>=8.4.1", "pytest-asyncio>=1.1.0", "pytest-cov>=6.2.1", - "xarray>=2023.1.0;python_version<'3.13'" + "xarray>=2023.1.0;python_version<'3.13'", + "psutil>=7.1.0", ] docs = [ "ezmsg-sigproc>=2.2.0", @@ -44,6 +45,7 @@ docs = [ [project.scripts] ezmsg = "ezmsg.core.command:cmdline" +ezmsg-perf = "ezmsg.util.perf.command:command" [project.optional-dependencies] axisarray = [ diff --git a/src/ezmsg/util/perf/__init__.py b/src/ezmsg/util/perf/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/ezmsg/util/perf/analysis.py b/src/ezmsg/util/perf/analysis.py new file mode 100644 index 00000000..48dd800e --- /dev/null +++ b/src/ezmsg/util/perf/analysis.py @@ -0,0 +1,113 @@ +import json +import typing +import dataclasses + +from pathlib import Path + +from ..messagecodec import MessageDecoder +from .util import ( + TestEnvironmentInfo, + TestParameters, + Metrics, +) + +import ezmsg.core as ez + +try: + import xarray as xr +except ImportError: + ez.logger.error('ezmsg perf analysis requires xarray') + raise + +try: + import numpy as np +except ImportError: + ez.logger.error('ezmsg perf analysis requires numpy') + raise + + +def load_perf(perf: Path) -> xr.Dataset: + + params: typing.List[TestParameters] = [] + results: typing.List[Metrics] = [] + + with open(perf, 'r') as perf_f: + info: TestEnvironmentInfo = json.loads(next(perf_f), cls = MessageDecoder) + + for line in perf_f: + obj = json.loads(line, cls = MessageDecoder) + params.append(obj['params']) + results.append(obj['results']) + + n_clients_axis = list(sorted(set([p.n_clients for p in params]))) + msg_size_axis = list(sorted(set([p.msg_size for p in params]))) + comms_axis = list(sorted(set([p.comms for p in params]))) + config_axis = list(sorted(set([p.config for p in params]))) + + dims = ['n_clients', 'msg_size', 'comms', 'config'] + coords = { + 'n_clients': n_clients_axis, + 'msg_size': msg_size_axis, + 'comms': comms_axis, + 'config': config_axis + } + + data_vars = {} + for field in dataclasses.fields(Metrics): + m = np.zeros(( + len(n_clients_axis), + len(msg_size_axis), + len(comms_axis), + len(config_axis) + )) * np.nan + for p, r in zip(params, results): + m[ + n_clients_axis.index(p.n_clients), + msg_size_axis.index(p.msg_size), + comms_axis.index(p.comms), + config_axis.index(p.config) + ] = getattr(r, field.name) + data_vars[field.name] = xr.DataArray(m, dims = dims, coords = coords) + + dataset = xr.Dataset(data_vars, attrs = dict(info = info)) + return dataset + +@dataclasses.dataclass +class SummaryArgs: + perf: Path + baseline: Path | None + +def summary(args: SummaryArgs): + + perf_dataset = load_perf(args.perf) + + for config, config_ds in perf_dataset.groupby('config'): + for comms, comms_ds in config_ds.groupby('comms'): + print(f'{config}: {comms}') + print(comms_ds.sample_rate.data) + + if args.baseline is not None: + baseline_dataset = load_perf(args.baseline) + + +def command() -> None: + import argparse + + parser = argparse.ArgumentParser() + + parser.add_argument( + "perf", + type=lambda x: Path(x), + help="perf test", + ) + + parser.add_argument( + "--baseline", "-b", + type=lambda x: Path(x), + default = None, + help="baseline perf test for comparison" + ) + + args = parser.parse_args(namespace=SummaryArgs) + + summary(args) diff --git a/src/ezmsg/util/perf/command.py b/src/ezmsg/util/perf/command.py new file mode 100644 index 00000000..5269a1d0 --- /dev/null +++ b/src/ezmsg/util/perf/command.py @@ -0,0 +1,55 @@ +from pathlib import Path + +from .analysis import summary, SummaryArgs +from .eval import perf_eval, PerfEvalArgs + +def command() -> None: + import argparse + + parser = argparse.ArgumentParser(description = 'ezmsg perf test utility') + subparsers = parser.add_subparsers(dest="command", required=True) + + p_run = subparsers.add_parser("run", help="run performance test") + p_run.add_argument( + "--duration", + type=float, + default=2.0, + help="individual test duration in seconds (default = 2.0)", + ) + p_run.add_argument( + "--num-buffers", + type=int, + default=32, + help="shared memory buffers (default = 32)", + ) + + p_run.set_defaults(_handler=lambda ns: perf_eval( + PerfEvalArgs( + duration = ns.duration, + num_buffers = ns.num_buffers + ) + )) + + p_summary = subparsers.add_parser("summary", help = "summarise performance results") + p_summary.add_argument( + "perf", + type=Path, + help="perf test", + ) + p_summary.add_argument( + "--baseline", + "-b", + type=Path, + default=None, + help="baseline perf test for comparison", + ) + + p_summary.set_defaults(_handler=lambda ns: summary( + SummaryArgs( + perf = ns.perf, + baseline = ns.baseline + ) + )) + + ns = parser.parse_args() + ns._handler(ns) diff --git a/src/ezmsg/util/perf/eval.py b/src/ezmsg/util/perf/eval.py new file mode 100644 index 00000000..894c45db --- /dev/null +++ b/src/ezmsg/util/perf/eval.py @@ -0,0 +1,91 @@ +import json +import datetime +import itertools + +from dataclasses import dataclass + +from ..messagecodec import MessageEncoder +from .util import ( + TestEnvironmentInfo, + TestParameters, + perform_test, + Communication, + CONFIGS, +) + +import ezmsg.core as ez + +def get_datestamp() -> str: + return datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + +@dataclass +class PerfEvalArgs: + duration: float + num_buffers: int + +def perf_eval(args: PerfEvalArgs) -> None: + + msg_sizes = [2 ** exp for exp in range(4, 25, 4)] + n_clients = [2 ** exp for exp in range(0, 6)] + comms = [c for c in Communication] + + test_list = list(itertools.product(msg_sizes, n_clients, CONFIGS, comms)) + + with open(f'perf_{get_datestamp()}.txt', 'w') as out_f: + + out_f.write(json.dumps(TestEnvironmentInfo(), cls = MessageEncoder) + "\n") + + for test_idx, (msg_size, n_clients, config, comms) in enumerate(test_list): + + ez.logger.info(f"RUNNING TEST {test_idx + 1} / {len(test_list)} ({(test_idx / len(test_list)) * 100.0:0.2f} %)") + + params = TestParameters( + msg_size = msg_size, + n_clients = n_clients, + config = config.__name__, + comms = comms.value, + duration = args.duration, + num_buffers = args.num_buffers + ) + + results = perform_test( + n_clients = n_clients, + duration = args.duration, + msg_size = msg_size, + buffers = args.num_buffers, + comms = comms, + config = config, + ) + + output = dict( + params = params, + results = results + ) + + out_f.write(json.dumps(output, cls = MessageEncoder) + "\n") + + +def command() -> None: + import argparse + + parser = argparse.ArgumentParser() + + parser.add_argument( + "--duration", + type=float, + default=2.0, + help="How long to run each load test (seconds) (default = 2.0)", + ) + + parser.add_argument( + "--num-buffers", + type=int, + default=32, + help="shared memory buffers (default = 32)" + ) + + args = parser.parse_args(namespace=PerfEvalArgs) + + perf_eval(args) + + diff --git a/src/ezmsg/util/perf/util.py b/src/ezmsg/util/perf/util.py new file mode 100644 index 00000000..8883591b --- /dev/null +++ b/src/ezmsg/util/perf/util.py @@ -0,0 +1,356 @@ +import asyncio +import dataclasses +import datetime +import os +import platform +import time +import typing +import enum +import sys +import subprocess + +import ezmsg.core as ez + +try: + import numpy as np +except ImportError: + ez.logger.error("ezmsg perf requires numpy") + raise + +try: + import psutil +except ImportError: + ez.logger.error("ezmsg perf requires psutil") + raise + + +def collect( + components: typing.Optional[typing.Mapping[str, ez.Component]] = None, + network: ez.NetworkDefinition = (), + process_components: typing.Collection[ez.Component] | None = None, + **components_kwargs: ez.Component, +) -> ez.Collection: + """ collect a grouping of pre-configured components into a new "Collection" """ + from ezmsg.core.util import either_dict_or_kwargs + + components = either_dict_or_kwargs(components, components_kwargs, "collect") + if components is None: + raise ValueError("Must supply at least one component to run") + + out = ez.Collection() + for name, comp in components.items(): + comp._set_name(name) + out._components = components # FIXME: Component._components should be typehinted as a Mapping + out.network = lambda: network + out.process_components = lambda: (out,) if process_components is None else process_components + return out + + +@dataclasses.dataclass +class Metrics: + num_msgs: int + sample_rate: float + latency_mean: float + latency_total: float + data_rate: float + + +class LoadTestSettings(ez.Settings): + duration: float + dynamic_size: int + buffers: int + force_tcp: bool + + +@dataclasses.dataclass +class LoadTestSample: + _timestamp: float + counter: int + dynamic_data: np.ndarray + key: str + + +class LoadTestSender(ez.Unit): + OUTPUT = ez.OutputStream(LoadTestSample) + SETTINGS = LoadTestSettings + + async def initialize(self) -> None: + self.running = True + self.counter = 0 + + self.OUTPUT.num_buffers = self.SETTINGS.buffers + self.OUTPUT.force_tcp = self.SETTINGS.force_tcp + + @ez.publisher(OUTPUT) + async def publish(self) -> typing.AsyncGenerator: + ez.logger.info(f"Load test publisher started. (PID: {os.getpid()})") + start_time = time.perf_counter() + while self.running: + current_time = time.perf_counter() + if current_time - start_time >= self.SETTINGS.duration: + break + + yield ( + self.OUTPUT, + LoadTestSample( + _timestamp=time.perf_counter(), + counter=self.counter, + dynamic_data=np.zeros( + int(self.SETTINGS.dynamic_size // 8), dtype=np.float32 + ), + key = self.name, + ), + ) + self.counter += 1 + ez.logger.info("Exiting publish") + raise ez.Complete + +class LoadTestSource(LoadTestSender): + async def shutdown(self) -> None: + self.running = False + ez.logger.info(f"Samples sent: {self.counter}") + + +class LoadTestRelay(ez.Unit): + INPUT = ez.InputStream(LoadTestSample) + OUTPUT = ez.OutputStream(LoadTestSample) + + @ez.subscriber(INPUT, zero_copy = True) + @ez.publisher(OUTPUT) + async def on_msg(self, msg: LoadTestSample) -> typing.AsyncGenerator: + yield self.OUTPUT, msg + + +class LoadTestReceiverState(ez.State): + # Tuples of sent timestamp, received timestamp, counter, dynamic size + received_data: typing.List[typing.Tuple[float, float, int]] = dataclasses.field( + default_factory=list + ) + counters: typing.Dict[str, int] = dataclasses.field(default_factory=dict) + + +class LoadTestReceiver(ez.Unit): + INPUT = ez.InputStream(LoadTestSample) + SETTINGS = LoadTestSettings + STATE = LoadTestReceiverState + + @ez.subscriber(INPUT, zero_copy=True) + async def receive(self, sample: LoadTestSample) -> None: + counter = self.STATE.counters.get(sample.key, -1) + if sample.counter != counter + 1: + ez.logger.warning( + f"{sample.counter - counter - 1} samples skipped!" + ) + self.STATE.received_data.append( + (sample._timestamp, time.perf_counter(), sample.counter) + ) + self.STATE.counters[sample.key] = sample.counter + + +class LoadTestSink(LoadTestReceiver): + + @ez.task + async def terminate(self) -> None: + ez.logger.info(f"Load test subscriber started. (PID: {os.getpid()})") + + # Wait for the duration of the load test + await asyncio.sleep(self.SETTINGS.duration) + raise ez.NormalTermination + + +### TEST CONFIGURATIONS + +@dataclasses.dataclass +class ConfigSettings: + n_clients: int + settings: LoadTestSettings + source: LoadTestSource + sink: LoadTestSink + +Configuration = typing.Tuple[typing.Iterable[ez.Component], ez.NetworkDefinition] +Configurator = typing.Callable[[ConfigSettings], Configuration] + +def fanout(config: ConfigSettings) -> Configuration: + """ one pub to many subs """ + connections: ez.NetworkDefinition = [(config.source.OUTPUT, config.sink.INPUT)] + subs = [LoadTestReceiver(config.settings) for _ in range(config.n_clients)] + for sub in subs: + connections.append((config.source.OUTPUT, sub.INPUT)) + + return subs, connections + +def fanin(config: ConfigSettings) -> Configuration: + """ many pubs to one sub """ + connections: ez.NetworkDefinition = [(config.source.OUTPUT, config.sink.INPUT)] + pubs = [LoadTestSender(config.settings) for _ in range(config.n_clients)] + for pub in pubs: + connections.append((pub.OUTPUT, config.sink.INPUT)) + return pubs, connections + + +def relay(config: ConfigSettings) -> Configuration: + """ one pub to one sub through many relays """ + connections: ez.NetworkDefinition = [] + + relays = [LoadTestRelay(config.settings) for _ in range(config.n_clients)] + connections.append((config.source.OUTPUT, relays[0].INPUT)) + for from_relay, to_relay in zip(relays[:-1], relays[1:]): + connections.append((from_relay.OUTPUT, to_relay.INPUT)) + connections.append((relays[-1].OUTPUT, config.sink.INPUT)) + + return relays, connections + +CONFIGS: typing.Iterable[Configurator] = [fanin, fanout, relay] + +class Communication(enum.StrEnum): + LOCAL = "local" + SHM = "shm" + SHM_SPREAD = "shm_spread" + TCP = "tcp" + TCP_SPREAD = "tcp_spread" + +def perform_test( + n_clients: int, + duration: float, + msg_size: int, + buffers: int, + comms: Communication, + config: Configurator +) -> Metrics: + + settings = LoadTestSettings( + dynamic_size = int(msg_size), + duration = duration, + buffers = buffers, + force_tcp = (comms == Communication.TCP), + ) + + source = LoadTestSource(settings) + sink = LoadTestSink(settings) + + components: typing.Mapping[str, ez.Component] = dict( + SINK = sink, + ) + + clients, connections = config(ConfigSettings(n_clients, settings, source, sink)) + + # The 'sink' MUST remain in this process for us to pull its state. + process_components: typing.Iterable[ez.Component] = [] + if comms == Communication.LOCAL: + # Every component in the same process (this one) + components["SOURCE"] = source + for i, client in enumerate(clients): + components[f"CLIENT_{i+1}"] = client + + else: + + if comms in (Communication.SHM_SPREAD, Communication.TCP_SPREAD): + # Every component in its own process. + components["SOURCE"] = source + process_components.append(source) + for i, client in enumerate(clients): + components[f'CLIENT_{i+1}'] = client + process_components.append(client) + + else: + # All clients and the source in ONE other process. + collect_comps: typing.Mapping[str, ez.Component] = dict() + collect_comps["SOURCE"] = source + for i, client in enumerate(clients): + collect_comps[f"CLIENT_{i+1}"] = client + proc_collection = collect(components = collect_comps) + components["PROC"] = proc_collection + process_components = [proc_collection] + + ez.run( + components = components, + connections = connections, + process_components = process_components, + ) + + return calculate_metrics(sink) + + +def calculate_metrics(sink: LoadTestSink) -> Metrics: + + # Log some useful summary statistics + min_timestamp = min(timestamp for timestamp, _, _ in sink.STATE.received_data) + max_timestamp = max(timestamp for timestamp, _, _ in sink.STATE.received_data) + total_latency = abs( + sum( + receive_timestamp - send_timestamp + for send_timestamp, receive_timestamp, _ in sink.STATE.received_data + ) + ) + + counters = list(sorted(t[2] for t in sink.STATE.received_data)) + dropped_samples = sum( + [max((x1 - x0) - 1, 0) for x1, x0 in zip(counters[1:], counters[:-1])] + ) + + num_samples = len(sink.STATE.received_data) + ez.logger.info(f"Samples received: {num_samples}") + sample_rate = num_samples / (max_timestamp - min_timestamp) + ez.logger.info(f"Sample rate: {sample_rate} Hz") + latency_mean = total_latency / num_samples + ez.logger.info(f"Mean latency: {latency_mean} s") + ez.logger.info(f"Total latency: {total_latency} s") + + total_data = num_samples * sink.SETTINGS.dynamic_size + data_rate = total_data / (max_timestamp - min_timestamp) + ez.logger.info(f"Data rate: {data_rate * 1e-6} MB/s") + + if dropped_samples: + ez.logger.error( + f"Dropped samples: {dropped_samples} ({dropped_samples / (dropped_samples + num_samples)}%)", + ) + + return Metrics( + num_msgs = num_samples, + sample_rate = sample_rate, + latency_mean = latency_mean, + latency_total = total_latency, + data_rate = data_rate + ) + + +@dataclasses.dataclass +class TestParameters: + msg_size: int + n_clients: int + config: str + comms: str + duration: float + num_buffers: int + +def _git_commit() -> str: + try: + return subprocess.check_output( + ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL + ).decode().strip() + except: + return "unknown" + +def _git_branch() -> str: + try: + return subprocess.check_output( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], stderr=subprocess.DEVNULL + ).decode().strip() + except: + return "unknown" + +@dataclasses.dataclass +class TestEnvironmentInfo: + ezmsg_version: str = dataclasses.field(default_factory=lambda: ez.__version__) + numpy_version: str = dataclasses.field(default_factory=lambda: np.__version__) + python_version: str = dataclasses.field(default_factory=lambda: sys.version.replace("\n", " ")) + os: str = dataclasses.field(default_factory=lambda: platform.system()) + os_version: str = dataclasses.field(default_factory=lambda: platform.version()) + machine: str = dataclasses.field(default_factory=lambda: platform.machine()) + processor: str = dataclasses.field(default_factory=lambda: platform.processor()) + cpu_count_logical: int | None = dataclasses.field(default_factory=lambda: psutil.cpu_count(logical=True)) + cpu_count_physical: int | None = dataclasses.field(default_factory=lambda: psutil.cpu_count(logical=False)) + memory_gb: float = dataclasses.field(default_factory=lambda:round(psutil.virtual_memory().total / (1024**3), 2)) + start_time: str = dataclasses.field(default_factory=lambda: datetime.datetime.now().isoformat(timespec="seconds")) + git_commit: str = dataclasses.field(default_factory=_git_commit) + git_branch: str = dataclasses.field(default_factory=_git_branch) \ No newline at end of file diff --git a/src/ezmsg/util/perf_test.py b/src/ezmsg/util/perf_test.py deleted file mode 100644 index b7536b5e..00000000 --- a/src/ezmsg/util/perf_test.py +++ /dev/null @@ -1,220 +0,0 @@ -import asyncio -import dataclasses -import datetime -import os -import platform -import time - -from typing import List, Tuple, AsyncGenerator - -import ezmsg.core as ez - -# We expect this test to generate LOTS of backpressure warnings -# PERF_LOGLEVEL = os.environ.get("EZMSG_LOGLEVEL", "ERROR") -# ez.logger.setLevel(PERF_LOGLEVEL) - -PLATFORM = { - "Darwin": "mac", - "Linux": "linux", - "Windows": "win", -}[platform.system()] -SAMPLE_SUMMARY_DATASET_PREFIX = "sample_summary" -COUNT_DATASET_NAME = "count" - - -try: - import numpy as np -except ImportError: - ez.logger.error("This test requires Numpy to run.") - raise - - -def get_datestamp() -> str: - return datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - - -class LoadTestSettings(ez.Settings): - duration: float = 30.0 - dynamic_size: int = 8 - buffers: int = 32 - - -@dataclasses.dataclass -class LoadTestSample: - _timestamp: float - counter: int - dynamic_data: np.ndarray - - -class LoadTestPublisher(ez.Unit): - OUTPUT = ez.OutputStream(LoadTestSample) - SETTINGS = LoadTestSettings - - async def initialize(self) -> None: - self.running = True - self.counter = 0 - self.OUTPUT.num_buffers = self.SETTINGS.buffers - - @ez.publisher(OUTPUT) - async def publish(self) -> AsyncGenerator: - ez.logger.info(f"Load test publisher started. (PID: {os.getpid()})") - start_time = time.time() - while self.running: - current_time = time.time() - if current_time - start_time >= self.SETTINGS.duration: - break - - yield ( - self.OUTPUT, - LoadTestSample( - _timestamp=time.time(), - counter=self.counter, - dynamic_data=np.zeros( - int(self.SETTINGS.dynamic_size // 8), dtype=np.float32 - ), - ), - ) - self.counter += 1 - ez.logger.info("Exiting publish") - raise ez.Complete - - async def shutdown(self) -> None: - self.running = False - ez.logger.info(f"Samples sent: {self.counter}") - - -class LoadTestSubscriberState(ez.State): - # Tuples of sent timestamp, received timestamp, counter, dynamic size - received_data: List[Tuple[float, float, int]] = dataclasses.field( - default_factory=list - ) - counter: int = -1 - - -class LoadTestSubscriber(ez.Unit): - INPUT = ez.InputStream(LoadTestSample) - SETTINGS = LoadTestSettings - STATE = LoadTestSubscriberState - - @ez.subscriber(INPUT, zero_copy=True) - async def receive(self, sample: LoadTestSample) -> None: - if sample.counter != self.STATE.counter + 1: - ez.logger.warning( - f"{sample.counter - self.STATE.counter - 1} samples skipped!" - ) - self.STATE.received_data.append( - (sample._timestamp, time.time(), sample.counter) - ) - self.STATE.counter = sample.counter - - @ez.task - async def log_result(self) -> None: - ez.logger.info(f"Load test subscriber started. (PID: {os.getpid()})") - - # Wait for the duration of the load test - await asyncio.sleep(self.SETTINGS.duration) - # logger.info(f"STATE = {self.STATE.received_data}") - - # Log some useful summary statistics - min_timestamp = min(timestamp for timestamp, _, _ in self.STATE.received_data) - max_timestamp = max(timestamp for timestamp, _, _ in self.STATE.received_data) - total_latency = abs( - sum( - receive_timestamp - send_timestamp - for send_timestamp, receive_timestamp, _ in self.STATE.received_data - ) - ) - - counters = list(sorted(t[2] for t in self.STATE.received_data)) - dropped_samples = sum( - [(x1 - x0) - 1 for x1, x0 in zip(counters[1:], counters[:-1])] - ) - - num_samples = len(self.STATE.received_data) - ez.logger.info(f"Samples received: {num_samples}") - ez.logger.info( - f"Sample rate: {num_samples / (max_timestamp - min_timestamp)} Hz" - ) - ez.logger.info(f"Mean latency: {total_latency / num_samples} s") - ez.logger.info(f"Total latency: {total_latency} s") - - total_data = num_samples * self.SETTINGS.dynamic_size - ez.logger.info( - f"Data rate: {total_data / (max_timestamp - min_timestamp) * 1e-6} MB/s" - ) - ez.logger.info( - f"Dropped samples: {dropped_samples} ({dropped_samples / (dropped_samples + num_samples)}%)", - ) - - raise ez.NormalTermination - - -class LoadTest(ez.Collection): - SETTINGS = LoadTestSettings - - PUBLISHER = LoadTestPublisher() - SUBSCRIBER = LoadTestSubscriber() - - def configure(self) -> None: - self.PUBLISHER.apply_settings(self.SETTINGS) - self.SUBSCRIBER.apply_settings(self.SETTINGS) - - def network(self) -> ez.NetworkDefinition: - return ((self.PUBLISHER.OUTPUT, self.SUBSCRIBER.INPUT),) - - def process_components(self): - return ( - self.PUBLISHER, - self.SUBSCRIBER, - ) - - -def get_time() -> float: - # time.perf_counter() isn't system-wide on Windows Python 3.6: - # https://bugs.python.org/issue37205 - return time.time() if PLATFORM == "win" else time.perf_counter() - - -def test_performance(duration, size, buffers) -> None: - ez.logger.info(f"Running load test for dynamic size: {size} bytes") - system = LoadTest( - LoadTestSettings(dynamic_size=int(size), duration=duration, buffers=buffers) - ) - ez.run(SYSTEM=system) - - -def run_many_dynamic_sizes(duration, buffers) -> None: - for exp in range(5, 22, 4): - test_performance(duration, 2**exp, buffers) - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument( - "--many-dynamic-sizes", - action="store_true", - help="Run load test for many dynamic sizes", - ) - parser.add_argument( - "--duration", - type=int, - default=2, - help="How long to run the load test (seconds)", - ) - parser.add_argument( - "--num-buffers", type=int, default=32, help="Shared memory buffers" - ) - - class Args: - many_dynamic_sizes: bool - duration: int - num_buffers: int - - args = parser.parse_args(namespace=Args) - - if args.many_dynamic_sizes: - run_many_dynamic_sizes(args.duration, args.num_buffers) - else: - test_performance(args.duration, 8, args.num_buffers) From 0722bc9c8c47544df2377cb8b93e5564136678c2 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Sat, 20 Sep 2025 08:50:38 -0400 Subject: [PATCH 02/88] less aggressive test strategy --- src/ezmsg/util/perf/eval.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ezmsg/util/perf/eval.py b/src/ezmsg/util/perf/eval.py index 894c45db..f7f7e27b 100644 --- a/src/ezmsg/util/perf/eval.py +++ b/src/ezmsg/util/perf/eval.py @@ -25,8 +25,8 @@ class PerfEvalArgs: def perf_eval(args: PerfEvalArgs) -> None: - msg_sizes = [2 ** exp for exp in range(4, 25, 4)] - n_clients = [2 ** exp for exp in range(0, 6)] + msg_sizes = [2 ** exp for exp in range(4, 25, 8)] + n_clients = [2 ** exp for exp in range(0, 6, 2)] comms = [c for c in Communication] test_list = list(itertools.product(msg_sizes, n_clients, CONFIGS, comms)) From d1fc6e3874b9506420e0abd94c6fbd2c9359ede2 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Sat, 20 Sep 2025 10:00:16 -0400 Subject: [PATCH 03/88] slight refactor --- src/ezmsg/util/perf/analysis.py | 60 ++++++++--------- src/ezmsg/util/perf/command.py | 53 ++------------- src/ezmsg/util/perf/envinfo.py | 83 ++++++++++++++++++++++++ src/ezmsg/util/perf/{util.py => impl.py} | 41 ------------ src/ezmsg/util/perf/{eval.py => run.py} | 42 ++++++------ 5 files changed, 139 insertions(+), 140 deletions(-) create mode 100644 src/ezmsg/util/perf/envinfo.py rename src/ezmsg/util/perf/{util.py => impl.py} (84%) rename src/ezmsg/util/perf/{eval.py => run.py} (72%) diff --git a/src/ezmsg/util/perf/analysis.py b/src/ezmsg/util/perf/analysis.py index 48dd800e..0be97744 100644 --- a/src/ezmsg/util/perf/analysis.py +++ b/src/ezmsg/util/perf/analysis.py @@ -1,12 +1,13 @@ import json import typing import dataclasses +import argparse from pathlib import Path from ..messagecodec import MessageDecoder -from .util import ( - TestEnvironmentInfo, +from .envinfo import TestEnvironmentInfo +from .impl import ( TestParameters, Metrics, ) @@ -72,42 +73,41 @@ def load_perf(perf: Path) -> xr.Dataset: dataset = xr.Dataset(data_vars, attrs = dict(info = info)) return dataset -@dataclasses.dataclass -class SummaryArgs: - perf: Path - baseline: Path | None -def summary(args: SummaryArgs): +def summary(perf_path: Path, baseline_path: Path | None) -> None: + """ print perf test results and comparisons to the console """ - perf_dataset = load_perf(args.perf) + perf = load_perf(perf_path) + info = perf.attrs['info'] + if baseline_path is not None: + baseline = load_perf(baseline_path) + perf = (perf / baseline) * 100.0 - for config, config_ds in perf_dataset.groupby('config'): - for comms, comms_ds in config_ds.groupby('comms'): - print(f'{config}: {comms}') - print(comms_ds.sample_rate.data) + print(info) - if args.baseline is not None: - baseline_dataset = load_perf(args.baseline) + for _, config_ds in perf.groupby('config'): + for _, comms_ds in config_ds.groupby('comms'): + print(comms_ds.squeeze().to_dataframe()) + print("\n") + print("\n") -def command() -> None: - import argparse - - parser = argparse.ArgumentParser() - - parser.add_argument( +def setup_summary_cmdline(subparsers: argparse._SubParsersAction) -> None: + p_summary = subparsers.add_parser("summary", help = "summarize performance results") + p_summary.add_argument( "perf", - type=lambda x: Path(x), + type=Path, help="perf test", ) - - parser.add_argument( - "--baseline", "-b", - type=lambda x: Path(x), - default = None, - help="baseline perf test for comparison" + p_summary.add_argument( + "--baseline", + "-b", + type=Path, + default=None, + help="baseline perf test for comparison", ) - args = parser.parse_args(namespace=SummaryArgs) - - summary(args) + p_summary.set_defaults(_handler=lambda ns: summary( + perf_path = ns.perf, + baseline_path = ns.baseline + )) \ No newline at end of file diff --git a/src/ezmsg/util/perf/command.py b/src/ezmsg/util/perf/command.py index 5269a1d0..a100863d 100644 --- a/src/ezmsg/util/perf/command.py +++ b/src/ezmsg/util/perf/command.py @@ -1,55 +1,14 @@ -from pathlib import Path +import argparse -from .analysis import summary, SummaryArgs -from .eval import perf_eval, PerfEvalArgs +from .analysis import setup_summary_cmdline +from .run import setup_run_cmdline def command() -> None: - import argparse - parser = argparse.ArgumentParser(description = 'ezmsg perf test utility') subparsers = parser.add_subparsers(dest="command", required=True) - p_run = subparsers.add_parser("run", help="run performance test") - p_run.add_argument( - "--duration", - type=float, - default=2.0, - help="individual test duration in seconds (default = 2.0)", - ) - p_run.add_argument( - "--num-buffers", - type=int, - default=32, - help="shared memory buffers (default = 32)", - ) - - p_run.set_defaults(_handler=lambda ns: perf_eval( - PerfEvalArgs( - duration = ns.duration, - num_buffers = ns.num_buffers - ) - )) - - p_summary = subparsers.add_parser("summary", help = "summarise performance results") - p_summary.add_argument( - "perf", - type=Path, - help="perf test", - ) - p_summary.add_argument( - "--baseline", - "-b", - type=Path, - default=None, - help="baseline perf test for comparison", - ) - - p_summary.set_defaults(_handler=lambda ns: summary( - SummaryArgs( - perf = ns.perf, - baseline = ns.baseline - ) - )) - + setup_run_cmdline(subparsers) + setup_summary_cmdline(subparsers) + ns = parser.parse_args() ns._handler(ns) diff --git a/src/ezmsg/util/perf/envinfo.py b/src/ezmsg/util/perf/envinfo.py new file mode 100644 index 00000000..66fff18b --- /dev/null +++ b/src/ezmsg/util/perf/envinfo.py @@ -0,0 +1,83 @@ +import dataclasses +import datetime +import platform +import typing +import sys +import subprocess + +import ezmsg.core as ez + +try: + import numpy as np +except ImportError: + ez.logger.error("ezmsg perf requires numpy") + raise + +try: + import psutil +except ImportError: + ez.logger.error("ezmsg perf requires psutil") + raise + + +def _git_commit() -> str: + try: + return subprocess.check_output( + ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL + ).decode().strip() + except: + return "unknown" + +def _git_branch() -> str: + try: + return subprocess.check_output( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], stderr=subprocess.DEVNULL + ).decode().strip() + except: + return "unknown" + +@dataclasses.dataclass +class TestEnvironmentInfo: + ezmsg_version: str = dataclasses.field(default_factory=lambda: ez.__version__) + numpy_version: str = dataclasses.field(default_factory=lambda: np.__version__) + python_version: str = dataclasses.field(default_factory=lambda: sys.version.replace("\n", " ")) + os: str = dataclasses.field(default_factory=lambda: platform.system()) + os_version: str = dataclasses.field(default_factory=lambda: platform.version()) + machine: str = dataclasses.field(default_factory=lambda: platform.machine()) + processor: str = dataclasses.field(default_factory=lambda: platform.processor()) + cpu_count_logical: int | None = dataclasses.field(default_factory=lambda: psutil.cpu_count(logical=True)) + cpu_count_physical: int | None = dataclasses.field(default_factory=lambda: psutil.cpu_count(logical=False)) + memory_gb: float = dataclasses.field(default_factory=lambda:round(psutil.virtual_memory().total / (1024**3), 2)) + start_time: str = dataclasses.field(default_factory=lambda: datetime.datetime.now().isoformat(timespec="seconds")) + git_commit: str = dataclasses.field(default_factory=_git_commit) + git_branch: str = dataclasses.field(default_factory=_git_branch) + + def __str__(self) -> str: + fields = dataclasses.asdict(self) + width = max(len(k) for k in fields) + lines = ["TestEnvironmentInfo:"] + for key, value in fields.items(): + lines.append(f" {key.ljust(width)} : {value}") + return "\n".join(lines) + + def diff(self, other: "TestEnvironmentInfo") -> typing.Dict[str, typing.Tuple[typing.Any, typing.Any]]: + """Return a structured diff: {field: (self_value, other_value)} for changed fields.""" + a = dataclasses.asdict(self) + b = dataclasses.asdict(other) + keys = set(a) | set(b) + return {k: (a.get(k), b.get(k)) for k in keys if a.get(k) != b.get(k)} + + +def format_env_diff(diffs: typing.Dict[str, typing.Tuple[typing.Any, typing.Any]]) -> str: + """Pretty-print the structured diff in the same aligned style.""" + if not diffs: + return "No differences." + width = max(len(k) for k in diffs) + lines = ["Differences in TestEnvironmentInfo:"] + for k in sorted(diffs): + left, right = diffs[k] + lines.append(f" {k.ljust(width)} : {left} != {right}") + return "\n".join(lines) + +def diff_envs(a: TestEnvironmentInfo, b: TestEnvironmentInfo) -> str: + return format_env_diff(a.diff(b)) \ No newline at end of file diff --git a/src/ezmsg/util/perf/util.py b/src/ezmsg/util/perf/impl.py similarity index 84% rename from src/ezmsg/util/perf/util.py rename to src/ezmsg/util/perf/impl.py index 8883591b..58357b8c 100644 --- a/src/ezmsg/util/perf/util.py +++ b/src/ezmsg/util/perf/impl.py @@ -1,13 +1,9 @@ import asyncio import dataclasses -import datetime import os -import platform import time import typing import enum -import sys -import subprocess import ezmsg.core as ez @@ -17,12 +13,6 @@ ez.logger.error("ezmsg perf requires numpy") raise -try: - import psutil -except ImportError: - ez.logger.error("ezmsg perf requires psutil") - raise - def collect( components: typing.Optional[typing.Mapping[str, ez.Component]] = None, @@ -323,34 +313,3 @@ class TestParameters: duration: float num_buffers: int -def _git_commit() -> str: - try: - return subprocess.check_output( - ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL - ).decode().strip() - except: - return "unknown" - -def _git_branch() -> str: - try: - return subprocess.check_output( - ["git", "rev-parse", "--abbrev-ref", "HEAD"], stderr=subprocess.DEVNULL - ).decode().strip() - except: - return "unknown" - -@dataclasses.dataclass -class TestEnvironmentInfo: - ezmsg_version: str = dataclasses.field(default_factory=lambda: ez.__version__) - numpy_version: str = dataclasses.field(default_factory=lambda: np.__version__) - python_version: str = dataclasses.field(default_factory=lambda: sys.version.replace("\n", " ")) - os: str = dataclasses.field(default_factory=lambda: platform.system()) - os_version: str = dataclasses.field(default_factory=lambda: platform.version()) - machine: str = dataclasses.field(default_factory=lambda: platform.machine()) - processor: str = dataclasses.field(default_factory=lambda: platform.processor()) - cpu_count_logical: int | None = dataclasses.field(default_factory=lambda: psutil.cpu_count(logical=True)) - cpu_count_physical: int | None = dataclasses.field(default_factory=lambda: psutil.cpu_count(logical=False)) - memory_gb: float = dataclasses.field(default_factory=lambda:round(psutil.virtual_memory().total / (1024**3), 2)) - start_time: str = dataclasses.field(default_factory=lambda: datetime.datetime.now().isoformat(timespec="seconds")) - git_commit: str = dataclasses.field(default_factory=_git_commit) - git_branch: str = dataclasses.field(default_factory=_git_branch) \ No newline at end of file diff --git a/src/ezmsg/util/perf/eval.py b/src/ezmsg/util/perf/run.py similarity index 72% rename from src/ezmsg/util/perf/eval.py rename to src/ezmsg/util/perf/run.py index f7f7e27b..d36dd02c 100644 --- a/src/ezmsg/util/perf/eval.py +++ b/src/ezmsg/util/perf/run.py @@ -1,12 +1,13 @@ import json import datetime import itertools +import argparse from dataclasses import dataclass from ..messagecodec import MessageEncoder -from .util import ( - TestEnvironmentInfo, +from .envinfo import TestEnvironmentInfo +from .impl import ( TestParameters, perform_test, Communication, @@ -19,11 +20,11 @@ def get_datestamp() -> str: return datetime.datetime.now().strftime("%Y%m%d_%H%M%S") @dataclass -class PerfEvalArgs: +class PerfRunArgs: duration: float num_buffers: int -def perf_eval(args: PerfEvalArgs) -> None: +def perf_run(args: PerfRunArgs) -> None: msg_sizes = [2 ** exp for exp in range(4, 25, 8)] n_clients = [2 ** exp for exp in range(0, 6, 2)] @@ -64,28 +65,25 @@ def perf_eval(args: PerfEvalArgs) -> None: out_f.write(json.dumps(output, cls = MessageEncoder) + "\n") +def setup_run_cmdline(subparsers: argparse._SubParsersAction) -> None: -def command() -> None: - import argparse - - parser = argparse.ArgumentParser() - - parser.add_argument( + p_run = subparsers.add_parser("run", help="run performance test") + p_run.add_argument( "--duration", type=float, default=2.0, - help="How long to run each load test (seconds) (default = 2.0)", + help="individual test duration in seconds (default = 2.0)", ) - - parser.add_argument( - "--num-buffers", - type=int, - default=32, - help="shared memory buffers (default = 32)" + p_run.add_argument( + "--num-buffers", + type=int, + default=32, + help="shared memory buffers (default = 32)", ) - args = parser.parse_args(namespace=PerfEvalArgs) - - perf_eval(args) - - + p_run.set_defaults(_handler=lambda ns: perf_run( + PerfRunArgs( + duration = ns.duration, + num_buffers = ns.num_buffers + ) + )) \ No newline at end of file From b8841180299e52bc560a53854cd8b15dbd398fa0 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Sat, 20 Sep 2025 13:28:20 -0400 Subject: [PATCH 04/88] minor --- src/ezmsg/util/perf/ai.py | 111 ++++++++++++++++++++++++++++++++ src/ezmsg/util/perf/analysis.py | 38 ++++++++--- 2 files changed, 141 insertions(+), 8 deletions(-) create mode 100644 src/ezmsg/util/perf/ai.py diff --git a/src/ezmsg/util/perf/ai.py b/src/ezmsg/util/perf/ai.py new file mode 100644 index 00000000..a212b81c --- /dev/null +++ b/src/ezmsg/util/perf/ai.py @@ -0,0 +1,111 @@ +import os +import json +import textwrap +import urllib.request + +DEFAULT_TEST_DESCRIPTION = """\ +You are analyzing performance test results for the ezmsg pub/sub system. + +Configurations (config): +- fanin: Many publishers to one subscriber +- fanout: one publisher to many subscribers +- relay: one publisher to one subscriber through many relays + +Communication strategies (comms): +- local: all subs, relays, and pubs are in the SAME process +- shm / tcp: some clients move to a second process; comms via shared memory / TCP + * fanin: all publishers moved + * fanout: all subscribers moved + * relay: the publisher and all relay nodes moved +- shm_spread / tcp_spread: each client in its own process; comms via SHM / TCP respectively + +Variables: +- n_clients: pubs (fanin), subs (fanout), or relays (relay) +- msg_size: nominal message size (bytes) + +Metrics: +- sample_rate: messages/sec at the sink (higher = better) +- data_rate: bytes/sec at the sink (higher = better) +- latency_mean: average send -> receive latency in seconds (lower = better) + +Task: +Summarize performance test results. Explain trade-offs by comms/config, call out anomalies/outliers. +Keep the tone concise, technical, and actionable. +Please format output such that it will display nicely in a terminal output. + +If the rest results are a "PERFORMANCE COMPARISON": +- Metrics are in percentages (100.0 = 100 percent = no change) and do not reflect the ground-truth physical units. +- Summarize key improvements/regressions +- Performance differences +/- 5 percent are likely in the noise. +""" + +def chatgpt_analyze_results( + results_text: str, + *, + prompt: str | None = None, + model: str | None = None, + max_chars: int = 120_000, + temperature: float = 0.2, +) -> str: + """ + Send results + a test description to OpenAI's Responses API and print the analysis. + + Env vars: + - OPENAI_API_KEY (required) + - OPENAI_MODEL (optional; e.g., 'gpt-4o-mini' or a newer model) + """ + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + raise RuntimeError("Please set OPENAI_API_KEY in your environment.") + + model = model or os.getenv("OPENAI_MODEL", "gpt-4o-mini") + + # Keep requests reasonable in size + results_snippet = results_text if len(results_text) <= max_chars else ( + results_text[:max_chars] + "\n\n[...truncated for token budget...]" + ) + + # You can tweak the system instruction here to steer tone/format + system_instruction = "You are a senior performance engineer. Prefer precise, structured analysis." + + user_payload = textwrap.dedent(f"""\ + {prompt or DEFAULT_TEST_DESCRIPTION} + + === BEGIN RESULTS === + {results_snippet} + === END RESULTS === + """) + + body = { + "model": model, + "temperature": temperature, + "input": [ + {"role": "system", "content": [{"type": "text", "text": system_instruction}]}, + {"role": "user", "content": [{"type": "text", "text": user_payload}]}, + ], + } + + req = urllib.request.Request( + "https://api.openai.com/v1/responses", + data=json.dumps(body).encode("utf-8"), + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}", + }, + method="POST", + ) + + with urllib.request.urlopen(req) as resp: + data = json.load(resp) + + # Robust extraction: prefer output_text; fall back to concatenating output content + text = data.get("output_text") + if not text: + parts = [] + for item in data.get("output", []) or []: + for c in item.get("content", []) or []: + if c.get("type") in ("output_text", "text") and "text" in c: + parts.append(c["text"]) + text = "\n".join(parts) if parts else json.dumps(data, indent=2) + + return text diff --git a/src/ezmsg/util/perf/analysis.py b/src/ezmsg/util/perf/analysis.py index 0be97744..be0cb1e7 100644 --- a/src/ezmsg/util/perf/analysis.py +++ b/src/ezmsg/util/perf/analysis.py @@ -6,7 +6,8 @@ from pathlib import Path from ..messagecodec import MessageDecoder -from .envinfo import TestEnvironmentInfo +from .envinfo import TestEnvironmentInfo, format_env_diff +from .ai import chatgpt_analyze_results from .impl import ( TestParameters, Metrics, @@ -74,22 +75,37 @@ def load_perf(perf: Path) -> xr.Dataset: return dataset -def summary(perf_path: Path, baseline_path: Path | None) -> None: +def summary(perf_path: Path, baseline_path: Path | None, ai: bool = False) -> None: """ print perf test results and comparisons to the console """ + output = '' + perf = load_perf(perf_path) - info = perf.attrs['info'] + info: TestEnvironmentInfo = perf.attrs['info'] + output += str(info) + '\n\n' + + if baseline_path is not None: + output += "PERFORMANCE COMPARISON\n\n" baseline = load_perf(baseline_path) perf = (perf / baseline) * 100.0 + baseline_info: TestEnvironmentInfo = baseline.attrs['info'] + output += format_env_diff(info.diff(baseline_info)) + '\n\n' - print(info) + # These raw stats are still valuable to have, but are confusing + # when making relative comparisons + perf = perf.drop_vars(['latency_total', 'num_msgs']) for _, config_ds in perf.groupby('config'): for _, comms_ds in config_ds.groupby('comms'): - print(comms_ds.squeeze().to_dataframe()) - print("\n") - print("\n") + output += str(comms_ds.squeeze().to_dataframe()) + '\n\n' + output += '\n' + + print(output) + + if ai: + print('Querying ChatGPT for AI-assisted analysis of performance test results') + print(chatgpt_analyze_results(output)) def setup_summary_cmdline(subparsers: argparse._SubParsersAction) -> None: @@ -106,8 +122,14 @@ def setup_summary_cmdline(subparsers: argparse._SubParsersAction) -> None: default=None, help="baseline perf test for comparison", ) + p_summary.add_argument( + "--ai", + action="store_true", + help="ask chatgpt for an analysis of the results. requires OPENAI_API_KEY set in environment" + ) p_summary.set_defaults(_handler=lambda ns: summary( perf_path = ns.perf, - baseline_path = ns.baseline + baseline_path = ns.baseline, + ai = ns.ai )) \ No newline at end of file From 6beb4baf90f8b39794f8d56bfa46770f34296b1d Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Sat, 20 Sep 2025 17:40:44 -0400 Subject: [PATCH 05/88] html analysis output; ai was a bust --- src/ezmsg/util/perf/ai.py | 111 ------------ src/ezmsg/util/perf/analysis.py | 300 ++++++++++++++++++++++++++++++-- src/ezmsg/util/perf/run.py | 23 +++ 3 files changed, 312 insertions(+), 122 deletions(-) delete mode 100644 src/ezmsg/util/perf/ai.py diff --git a/src/ezmsg/util/perf/ai.py b/src/ezmsg/util/perf/ai.py deleted file mode 100644 index a212b81c..00000000 --- a/src/ezmsg/util/perf/ai.py +++ /dev/null @@ -1,111 +0,0 @@ -import os -import json -import textwrap -import urllib.request - -DEFAULT_TEST_DESCRIPTION = """\ -You are analyzing performance test results for the ezmsg pub/sub system. - -Configurations (config): -- fanin: Many publishers to one subscriber -- fanout: one publisher to many subscribers -- relay: one publisher to one subscriber through many relays - -Communication strategies (comms): -- local: all subs, relays, and pubs are in the SAME process -- shm / tcp: some clients move to a second process; comms via shared memory / TCP - * fanin: all publishers moved - * fanout: all subscribers moved - * relay: the publisher and all relay nodes moved -- shm_spread / tcp_spread: each client in its own process; comms via SHM / TCP respectively - -Variables: -- n_clients: pubs (fanin), subs (fanout), or relays (relay) -- msg_size: nominal message size (bytes) - -Metrics: -- sample_rate: messages/sec at the sink (higher = better) -- data_rate: bytes/sec at the sink (higher = better) -- latency_mean: average send -> receive latency in seconds (lower = better) - -Task: -Summarize performance test results. Explain trade-offs by comms/config, call out anomalies/outliers. -Keep the tone concise, technical, and actionable. -Please format output such that it will display nicely in a terminal output. - -If the rest results are a "PERFORMANCE COMPARISON": -- Metrics are in percentages (100.0 = 100 percent = no change) and do not reflect the ground-truth physical units. -- Summarize key improvements/regressions -- Performance differences +/- 5 percent are likely in the noise. -""" - -def chatgpt_analyze_results( - results_text: str, - *, - prompt: str | None = None, - model: str | None = None, - max_chars: int = 120_000, - temperature: float = 0.2, -) -> str: - """ - Send results + a test description to OpenAI's Responses API and print the analysis. - - Env vars: - - OPENAI_API_KEY (required) - - OPENAI_MODEL (optional; e.g., 'gpt-4o-mini' or a newer model) - """ - api_key = os.getenv("OPENAI_API_KEY") - if not api_key: - raise RuntimeError("Please set OPENAI_API_KEY in your environment.") - - model = model or os.getenv("OPENAI_MODEL", "gpt-4o-mini") - - # Keep requests reasonable in size - results_snippet = results_text if len(results_text) <= max_chars else ( - results_text[:max_chars] + "\n\n[...truncated for token budget...]" - ) - - # You can tweak the system instruction here to steer tone/format - system_instruction = "You are a senior performance engineer. Prefer precise, structured analysis." - - user_payload = textwrap.dedent(f"""\ - {prompt or DEFAULT_TEST_DESCRIPTION} - - === BEGIN RESULTS === - {results_snippet} - === END RESULTS === - """) - - body = { - "model": model, - "temperature": temperature, - "input": [ - {"role": "system", "content": [{"type": "text", "text": system_instruction}]}, - {"role": "user", "content": [{"type": "text", "text": user_payload}]}, - ], - } - - req = urllib.request.Request( - "https://api.openai.com/v1/responses", - data=json.dumps(body).encode("utf-8"), - headers={ - "Content-Type": "application/json", - "Authorization": f"Bearer {api_key}", - }, - method="POST", - ) - - with urllib.request.urlopen(req) as resp: - data = json.load(resp) - - # Robust extraction: prefer output_text; fall back to concatenating output content - text = data.get("output_text") - if not text: - parts = [] - for item in data.get("output", []) or []: - for c in item.get("content", []) or []: - if c.get("type") in ("output_text", "text") and "text" in c: - parts.append(c["text"]) - text = "\n".join(parts) if parts else json.dumps(data, indent=2) - - return text diff --git a/src/ezmsg/util/perf/analysis.py b/src/ezmsg/util/perf/analysis.py index be0cb1e7..2ddc7062 100644 --- a/src/ezmsg/util/perf/analysis.py +++ b/src/ezmsg/util/perf/analysis.py @@ -2,12 +2,16 @@ import typing import dataclasses import argparse +import html +import math +import webbrowser +import tempfile from pathlib import Path from ..messagecodec import MessageDecoder from .envinfo import TestEnvironmentInfo, format_env_diff -from .ai import chatgpt_analyze_results +from .run import get_datestamp from .impl import ( TestParameters, Metrics, @@ -17,6 +21,7 @@ try: import xarray as xr + import pandas as pd # xarray depends on pandas except ImportError: ez.logger.error('ezmsg perf analysis requires xarray') raise @@ -74,8 +79,186 @@ def load_perf(perf: Path) -> xr.Dataset: dataset = xr.Dataset(data_vars, attrs = dict(info = info)) return dataset +NOISE_BAND_PCT = 5.0 # +/-5% is "in the noise" for comparisons -def summary(perf_path: Path, baseline_path: Path | None, ai: bool = False) -> None: +def _escape(s: str) -> str: + return html.escape(str(s), quote=True) + +def _env_block(title: str, body: str) -> str: + return f""" +
+

{_escape(title)}

+
{_escape(body).strip()}
+
+ """ + +def _legend_block() -> str: + return """ +
+

Legend

+
    +
  • Comparison mode: values are percentages (100 = no change).
  • +
  • Noise band: ±5% considered negligible (no color).
  • +
  • Green: improvement (↑ sample/data rate, ↓ latency).
  • +
  • Red: regression (↓ sample/data rate, ↑ latency).
  • +
+
+ """ + +def _base_css() -> str: + # Minimal, print-friendly CSS + color scales for cells. + return """ + + """ + +def _color_for_comparison(value: float, metric: str) -> str: + """ + Returns inline CSS background for a comparison % value. + value: e.g., 97.3, 104.8, etc. + For sample_rate/data_rate: improvement > 100 (good). + For latency_mean: improvement < 100 (good). + Noise band ±5% around 100 is neutral. + """ + if not (isinstance(value, (int, float)) and math.isfinite(value)): + return "" + + delta = value - 100.0 + # Determine direction: + is good for sample/data; - is good for latency + if metric in ("sample_rate", "data_rate"): + # positive delta good, negative bad + magnitude = abs(delta) + sign_good = delta > 0 + elif metric == "latency_mean": + # negative delta good (lower latency) + magnitude = abs(delta) + sign_good = delta < 0 + else: + return "" + + # Noise band: keep neutral + if magnitude <= NOISE_BAND_PCT: + return "" + + # Scale 5%..50% across 0..1; clamp + scale = max(0.0, min(1.0, (magnitude - NOISE_BAND_PCT) / 45.0)) + + # Choose hue and lightness; use HSL with gentle saturation + hue = "var(--green)" if sign_good else "var(--red)" + # opacity via alpha blend on lightness via HSLa + # Use saturation ~70%, lightness around 40–50% blended with table bg + alpha = 0.15 + 0.35 * scale # 0.15..0.50 + return f"background-color: hsla({hue}, 70%, 45%, {alpha});" + +def _format_number(x) -> str: + if isinstance(x, (int,)) and not isinstance(x, bool): + return f"{x:d}" + try: + xf = float(x) + except Exception: + return _escape(str(x)) + # Heuristic: for comparison percentages, 1 decimal is nice; for absolute, 3 decimals for latency. + return f"{xf:.3f}" + + +def summary(perf_path: Path, baseline_path: Path | None, html: bool = False) -> None: """ print perf test results and comparisons to the console """ output = '' @@ -84,17 +267,21 @@ def summary(perf_path: Path, baseline_path: Path | None, ai: bool = False) -> No info: TestEnvironmentInfo = perf.attrs['info'] output += str(info) + '\n\n' - + relative = False + env_diff = None if baseline_path is not None: + relative = True output += "PERFORMANCE COMPARISON\n\n" baseline = load_perf(baseline_path) perf = (perf / baseline) * 100.0 baseline_info: TestEnvironmentInfo = baseline.attrs['info'] - output += format_env_diff(info.diff(baseline_info)) + '\n\n' + env_diff = format_env_diff(info.diff(baseline_info)) + output += env_diff + '\n\n' # These raw stats are still valuable to have, but are confusing # when making relative comparisons perf = perf.drop_vars(['latency_total', 'num_msgs']) + df = perf.squeeze().to_dataframe() for _, config_ds in perf.groupby('config'): for _, comms_ds in config_ds.groupby('comms'): @@ -103,9 +290,100 @@ def summary(perf_path: Path, baseline_path: Path | None, ai: bool = False) -> No print(output) - if ai: - print('Querying ChatGPT for AI-assisted analysis of performance test results') - print(chatgpt_analyze_results(output)) + if html: + # Ensure expected columns exist + expected_cols = {"sample_rate", "data_rate", "latency_mean"} + missing = expected_cols - set(df.columns) + if missing: + raise ValueError(f"Missing expected columns in dataset: {missing}") + + # We'll render a table per (config, comms) group. + groups = df.reset_index().sort_values( + by=["config", "comms", "n_clients", "msg_size"] + ).groupby(["config", "comms"], sort=False) + + # Build HTML + parts: list[str] = [] + parts.append("") + parts.append("") + parts.append("ezmsg perf report") + parts.append(_base_css()) + parts.append("
") + + parts.append("
") + parts.append("

ezmsg Performance Report

") + sub = str(perf_path) + if baseline_path is not None: + sub += f' relative to {str(baseline_path)}' + parts.append(f"
{_escape(sub)}
") + parts.append("
") + + if info is not None: + parts.append(_env_block("Test Environment", str(info))) + + if env_diff is not None: + # Show diffs using your helper + parts.append("
") + parts.append("

Environment Differences vs Baseline

") + parts.append(f"
{_escape(env_diff)}
") + parts.append("
") + parts.append(_legend_block()) + + # Render each group + for (config, comms), g in groups: + # Keep only expected columns in order + cols = ["n_clients", "msg_size", "sample_rate", "data_rate", "latency_mean"] + g = g[cols].copy() + + # String format some columns (msg_size with separators) + g["msg_size"] = g["msg_size"].map(lambda x: f"{int(x):,}" if pd.notna(x) else x) + + # Build table manually so we can inject inline cell styles easily + # (pandas Styler is great but produces bulky HTML; manual keeps it clean) + header = f""" + + + n_clients + msg_size {'' if relative else '(b)'} + sample_rate {'' if relative else '(msgs/s)'} + data_rate {'' if relative else '(MB/s)'} + latency_mean {'' if relative else '(us)'} + + + """ + body_rows: list[str] = [] + for _, row in g.iterrows(): + sr, dr, lt = row["sample_rate"], row["data_rate"], row["latency_mean"] + dr = dr if relative else dr / 2**20 + lt = lt if relative else lt * 1e6 + sr_style = _color_for_comparison(sr, "sample_rate") if relative else "" + dr_style = _color_for_comparison(dr, "data_rate") if relative else "" + lt_style = _color_for_comparison(lt, "latency_mean") if relative else "" + + body_rows.append( + "" + f"{_format_number(row['n_clients'])}" + f"{_escape(row['msg_size'])}" + f"{_format_number(sr)}" + f"{_format_number(dr)}" + f"{_format_number(lt)}" + "" + ) + table_html = f"{header}{''.join(body_rows)}
" + + parts.append( + f"

" + f"{_escape(config)}" + f"{_escape(comms)}" + f"

{table_html}
" + ) + + parts.append("
") + html_text = "".join(parts) + + out_path = Path(f'report_{get_datestamp()}.html') + out_path.write_text(html_text, encoding="utf-8") + webbrowser.open(out_path.resolve().as_uri()) def setup_summary_cmdline(subparsers: argparse._SubParsersAction) -> None: @@ -123,13 +401,13 @@ def setup_summary_cmdline(subparsers: argparse._SubParsersAction) -> None: help="baseline perf test for comparison", ) p_summary.add_argument( - "--ai", - action="store_true", - help="ask chatgpt for an analysis of the results. requires OPENAI_API_KEY set in environment" + "--html", + action = 'store_true', + help = "generate an html output file and render results in browser", ) p_summary.set_defaults(_handler=lambda ns: summary( perf_path = ns.perf, baseline_path = ns.baseline, - ai = ns.ai + html = ns.html )) \ No newline at end of file diff --git a/src/ezmsg/util/perf/run.py b/src/ezmsg/util/perf/run.py index d36dd02c..1856f2a6 100644 --- a/src/ezmsg/util/perf/run.py +++ b/src/ezmsg/util/perf/run.py @@ -25,6 +25,29 @@ class PerfRunArgs: num_buffers: int def perf_run(args: PerfRunArgs) -> None: + """ + Configurations (config): + - fanin: Many publishers to one subscriber + - fanout: one publisher to many subscribers + - relay: one publisher to one subscriber through many relays + + Communication strategies (comms): + - local: all subs, relays, and pubs are in the SAME process + - shm / tcp: some clients move to a second process; comms via shared memory / TCP + * fanin: all publishers moved + * fanout: all subscribers moved + * relay: the publisher and all relay nodes moved + - shm_spread / tcp_spread: each client in its own process; comms via SHM / TCP respectively + + Variables: + - n_clients: pubs (fanin), subs (fanout), or relays (relay) + - msg_size: nominal message size (bytes) + + Metrics: + - sample_rate: messages/sec at the sink (higher = better) + - data_rate: bytes/sec at the sink (higher = better) + - latency_mean: average send -> receive latency in seconds (lower = better) + """ msg_sizes = [2 ** exp for exp in range(4, 25, 8)] n_clients = [2 ** exp for exp in range(0, 6, 2)] From a661ce5bef1d70be552338b9d2a52c92782aed23 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Sat, 20 Sep 2025 17:49:44 -0400 Subject: [PATCH 06/88] report tweaks --- src/ezmsg/util/perf/analysis.py | 27 ++++++++++++++++++++++++++- src/ezmsg/util/perf/run.py | 24 ------------------------ 2 files changed, 26 insertions(+), 25 deletions(-) diff --git a/src/ezmsg/util/perf/analysis.py b/src/ezmsg/util/perf/analysis.py index 2ddc7062..9d33a068 100644 --- a/src/ezmsg/util/perf/analysis.py +++ b/src/ezmsg/util/perf/analysis.py @@ -5,7 +5,6 @@ import html import math import webbrowser -import tempfile from pathlib import Path @@ -32,6 +31,30 @@ ez.logger.error('ezmsg perf analysis requires numpy') raise +TEST_DESCRIPTION = """ +Configurations (config): +- fanin: many publishers to one subscriber +- fanout: one publisher to many subscribers +- relay: one publisher to one subscriber through many relays + +Communication strategies (comms): +- local: all subs, relays, and pubs are in the SAME process +- shm / tcp: some clients move to a second process; comms via shared memory / TCP + * fanin: all publishers moved + * fanout: all subscribers moved + * relay: the publisher and all relay nodes moved +- shm_spread / tcp_spread: each client in its own process; comms via SHM / TCP respectively + +Variables: +- n_clients: pubs (fanin), subs (fanout), or relays (relay) +- msg_size: nominal message size (bytes) + +Metrics: +- sample_rate: messages/sec at the sink (higher = better) +- data_rate: bytes/sec at the sink (higher = better) +- latency_mean: average send -> receive latency in seconds (lower = better) +""" + def load_perf(perf: Path) -> xr.Dataset: @@ -321,6 +344,8 @@ def summary(perf_path: Path, baseline_path: Path | None, html: bool = False) -> if info is not None: parts.append(_env_block("Test Environment", str(info))) + parts.append(_env_block("Test Details", TEST_DESCRIPTION)) + if env_diff is not None: # Show diffs using your helper parts.append("
") diff --git a/src/ezmsg/util/perf/run.py b/src/ezmsg/util/perf/run.py index 1856f2a6..d7280e51 100644 --- a/src/ezmsg/util/perf/run.py +++ b/src/ezmsg/util/perf/run.py @@ -25,30 +25,6 @@ class PerfRunArgs: num_buffers: int def perf_run(args: PerfRunArgs) -> None: - """ - Configurations (config): - - fanin: Many publishers to one subscriber - - fanout: one publisher to many subscribers - - relay: one publisher to one subscriber through many relays - - Communication strategies (comms): - - local: all subs, relays, and pubs are in the SAME process - - shm / tcp: some clients move to a second process; comms via shared memory / TCP - * fanin: all publishers moved - * fanout: all subscribers moved - * relay: the publisher and all relay nodes moved - - shm_spread / tcp_spread: each client in its own process; comms via SHM / TCP respectively - - Variables: - - n_clients: pubs (fanin), subs (fanout), or relays (relay) - - msg_size: nominal message size (bytes) - - Metrics: - - sample_rate: messages/sec at the sink (higher = better) - - data_rate: bytes/sec at the sink (higher = better) - - latency_mean: average send -> receive latency in seconds (lower = better) - """ - msg_sizes = [2 ** exp for exp in range(4, 25, 8)] n_clients = [2 ** exp for exp in range(0, 6, 2)] comms = [c for c in Communication] From 5464887ca95f58b75ec728ab1ef0defb49f53e0e Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Sun, 21 Sep 2025 09:52:18 -0400 Subject: [PATCH 07/88] added test iters --- src/ezmsg/util/perf/analysis.py | 23 ++++++++------- src/ezmsg/util/perf/impl.py | 5 ++++ src/ezmsg/util/perf/run.py | 52 +++++++++++++++++++-------------- 3 files changed, 47 insertions(+), 33 deletions(-) diff --git a/src/ezmsg/util/perf/analysis.py b/src/ezmsg/util/perf/analysis.py index 9d33a068..681c9dac 100644 --- a/src/ezmsg/util/perf/analysis.py +++ b/src/ezmsg/util/perf/analysis.py @@ -14,6 +14,7 @@ from .impl import ( TestParameters, Metrics, + TestLogEntry, ) import ezmsg.core as ez @@ -59,15 +60,15 @@ def load_perf(perf: Path) -> xr.Dataset: params: typing.List[TestParameters] = [] - results: typing.List[Metrics] = [] + results: typing.List[typing.List[Metrics]] = [] with open(perf, 'r') as perf_f: info: TestEnvironmentInfo = json.loads(next(perf_f), cls = MessageDecoder) for line in perf_f: - obj = json.loads(line, cls = MessageDecoder) - params.append(obj['params']) - results.append(obj['results']) + obj: TestLogEntry = json.loads(line, cls = MessageDecoder) + params.append(obj.params) + results.append(obj.results) n_clients_axis = list(sorted(set([p.n_clients for p in params]))) msg_size_axis = list(sorted(set([p.msg_size for p in params]))) @@ -91,19 +92,20 @@ def load_perf(perf: Path) -> xr.Dataset: len(config_axis) )) * np.nan for p, r in zip(params, results): + # tests are run multiple times; get the median value for each metric + values = list(sorted([getattr(v, field.name) for v in r])) + value = values[len(values)//2] m[ n_clients_axis.index(p.n_clients), msg_size_axis.index(p.msg_size), comms_axis.index(p.comms), config_axis.index(p.config) - ] = getattr(r, field.name) + ] = value data_vars[field.name] = xr.DataArray(m, dims = dims, coords = coords) dataset = xr.Dataset(data_vars, attrs = dict(info = info)) return dataset -NOISE_BAND_PCT = 5.0 # +/-5% is "in the noise" for comparisons - def _escape(s: str) -> str: return html.escape(str(s), quote=True) @@ -121,7 +123,6 @@ def _legend_block() -> str:

Legend

  • Comparison mode: values are percentages (100 = no change).
  • -
  • Noise band: ±5% considered negligible (no color).
  • Green: improvement (↑ sample/data rate, ↓ latency).
  • Red: regression (↓ sample/data rate, ↑ latency).
@@ -232,7 +233,7 @@ def _base_css() -> str: """ -def _color_for_comparison(value: float, metric: str) -> str: +def _color_for_comparison(value: float, metric: str, noise_band_pct: float = 5.0) -> str: """ Returns inline CSS background for a comparison % value. value: e.g., 97.3, 104.8, etc. @@ -257,11 +258,11 @@ def _color_for_comparison(value: float, metric: str) -> str: return "" # Noise band: keep neutral - if magnitude <= NOISE_BAND_PCT: + if magnitude <= noise_band_pct: return "" # Scale 5%..50% across 0..1; clamp - scale = max(0.0, min(1.0, (magnitude - NOISE_BAND_PCT) / 45.0)) + scale = max(0.0, min(1.0, (magnitude - noise_band_pct) / 45.0)) # Choose hue and lightness; use HSL with gentle saturation hue = "var(--green)" if sign_good else "var(--red)" diff --git a/src/ezmsg/util/perf/impl.py b/src/ezmsg/util/perf/impl.py index 58357b8c..5edc0474 100644 --- a/src/ezmsg/util/perf/impl.py +++ b/src/ezmsg/util/perf/impl.py @@ -313,3 +313,8 @@ class TestParameters: duration: float num_buffers: int + +@dataclasses.dataclass +class TestLogEntry: + params: TestParameters + results: typing.List[Metrics] \ No newline at end of file diff --git a/src/ezmsg/util/perf/run.py b/src/ezmsg/util/perf/run.py index d7280e51..599f3b02 100644 --- a/src/ezmsg/util/perf/run.py +++ b/src/ezmsg/util/perf/run.py @@ -8,7 +8,8 @@ from ..messagecodec import MessageEncoder from .envinfo import TestEnvironmentInfo from .impl import ( - TestParameters, + TestParameters, + TestLogEntry, perform_test, Communication, CONFIGS, @@ -19,12 +20,12 @@ def get_datestamp() -> str: return datetime.datetime.now().strftime("%Y%m%d_%H%M%S") -@dataclass -class PerfRunArgs: - duration: float - num_buffers: int -def perf_run(args: PerfRunArgs) -> None: +def perf_run( + duration: float, + num_buffers: int, + iters: int, +) -> None: msg_sizes = [2 ** exp for exp in range(4, 25, 8)] n_clients = [2 ** exp for exp in range(0, 6, 2)] comms = [c for c in Communication] @@ -44,20 +45,22 @@ def perf_run(args: PerfRunArgs) -> None: n_clients = n_clients, config = config.__name__, comms = comms.value, - duration = args.duration, - num_buffers = args.num_buffers + duration = duration, + num_buffers = num_buffers ) - results = perform_test( - n_clients = n_clients, - duration = args.duration, - msg_size = msg_size, - buffers = args.num_buffers, - comms = comms, - config = config, - ) - - output = dict( + results = [ + perform_test( + n_clients = n_clients, + duration = duration, + msg_size = msg_size, + buffers = num_buffers, + comms = comms, + config = config, + ) for _ in range(iters) + ] + + output = TestLogEntry( params = params, results = results ) @@ -79,10 +82,15 @@ def setup_run_cmdline(subparsers: argparse._SubParsersAction) -> None: default=32, help="shared memory buffers (default = 32)", ) + p_run.add_argument( + "--iters", "-i", + type = int, + default = 3, + help = "number of times to run each test" + ) p_run.set_defaults(_handler=lambda ns: perf_run( - PerfRunArgs( - duration = ns.duration, - num_buffers = ns.num_buffers - ) + duration = ns.duration, + num_buffers = ns.num_buffers, + iters = ns.iters )) \ No newline at end of file From 460ca9d230dcc55bc66dd1bf6236027e7296d734 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 22 Sep 2025 13:06:30 -0400 Subject: [PATCH 08/88] more cmdline --- src/ezmsg/util/perf/impl.py | 8 ++- src/ezmsg/util/perf/run.py | 99 +++++++++++++++++++++++++++++++------ 2 files changed, 90 insertions(+), 17 deletions(-) diff --git a/src/ezmsg/util/perf/impl.py b/src/ezmsg/util/perf/impl.py index 5edc0474..6bb0ded7 100644 --- a/src/ezmsg/util/perf/impl.py +++ b/src/ezmsg/util/perf/impl.py @@ -190,7 +190,13 @@ def relay(config: ConfigSettings) -> Configuration: return relays, connections -CONFIGS: typing.Iterable[Configurator] = [fanin, fanout, relay] +CONFIGS: typing.Mapping[str, Configurator] = { + c.__name__: c for c in [ + fanin, + fanout, + relay + ] +} class Communication(enum.StrEnum): LOCAL = "local" diff --git a/src/ezmsg/util/perf/run.py b/src/ezmsg/util/perf/run.py index 599f3b02..094389c2 100644 --- a/src/ezmsg/util/perf/run.py +++ b/src/ezmsg/util/perf/run.py @@ -2,8 +2,7 @@ import datetime import itertools import argparse - -from dataclasses import dataclass +import typing from ..messagecodec import MessageEncoder from .envinfo import TestEnvironmentInfo @@ -17,6 +16,10 @@ import ezmsg.core as ez +DEFAULT_MSG_SIZES = [2 ** exp for exp in range(4, 25, 8)] +DEFAULT_N_CLIENTS = [2 ** exp for exp in range(0, 6, 2)] +DEFAULT_COMMS = [c for c in Communication] + def get_datestamp() -> str: return datetime.datetime.now().strftime("%Y%m%d_%H%M%S") @@ -25,38 +28,62 @@ def perf_run( duration: float, num_buffers: int, iters: int, + msg_sizes: typing.Iterable[int] | None, + n_clients: typing.Iterable[int] | None, + comms: typing.Iterable[str] | None, + configs: typing.Iterable[str] | None, ) -> None: - msg_sizes = [2 ** exp for exp in range(4, 25, 8)] - n_clients = [2 ** exp for exp in range(0, 6, 2)] - comms = [c for c in Communication] - - test_list = list(itertools.product(msg_sizes, n_clients, CONFIGS, comms)) + + if n_clients is None: + n_clients = DEFAULT_N_CLIENTS + if any(c <= 0 for c in n_clients): + ez.logger.error('All tests must have >0 clients') + return + + if msg_sizes is None: + msg_sizes = DEFAULT_MSG_SIZES + if any(s <= 0 for s in msg_sizes): + ez.logger.error('All msg_sizes must be >0 bytes') + + try: + communications = DEFAULT_COMMS if comms is None else [Communication(c) for c in comms] + except ValueError: + ez.logger.error(f"Invalid test communications requested. Valid communications: {', '.join([c.value for c in Communication])}") + return + + try: + configurators = list(CONFIGS.values()) if configs is None else [CONFIGS[c] for c in configs] + except ValueError: + ez.logger.error(f"Invalid test configuration requested. Valid configurations: {', '.join([c for c in CONFIGS])}") + return + + test_list = list(itertools.product(msg_sizes, n_clients, configurators, communications)) with open(f'perf_{get_datestamp()}.txt', 'w') as out_f: out_f.write(json.dumps(TestEnvironmentInfo(), cls = MessageEncoder) + "\n") - for test_idx, (msg_size, n_clients, config, comms) in enumerate(test_list): + for test_idx, (msg_size, clients, conf, comm) in enumerate(test_list): ez.logger.info(f"RUNNING TEST {test_idx + 1} / {len(test_list)} ({(test_idx / len(test_list)) * 100.0:0.2f} %)") params = TestParameters( msg_size = msg_size, - n_clients = n_clients, - config = config.__name__, - comms = comms.value, + n_clients = clients, + config = conf.__name__, + comms = comm.value, duration = duration, num_buffers = num_buffers ) results = [ perform_test( - n_clients = n_clients, + n_clients = clients, duration = duration, msg_size = msg_size, buffers = num_buffers, - comms = comms, - config = config, + comms = comm, + config = conf, ) for _ in range(iters) ] @@ -70,27 +97,67 @@ def perf_run( def setup_run_cmdline(subparsers: argparse._SubParsersAction) -> None: p_run = subparsers.add_parser("run", help="run performance test") + p_run.add_argument( "--duration", type=float, default=2.0, help="individual test duration in seconds (default = 2.0)", ) + p_run.add_argument( "--num-buffers", type=int, default=32, help="shared memory buffers (default = 32)", ) + p_run.add_argument( "--iters", "-i", type = int, default = 3, - help = "number of times to run each test" + help = "number of times to run each test (default = 3)" ) + p_run.add_argument( + "--msg-sizes", + type = int, + default = None, + nargs = "*", + help = f"message sizes in bytes (default = {DEFAULT_MSG_SIZES})" + ) + + p_run.add_argument( + "--n-clients", + type = int, + default = None, + nargs = "*", + help = f"number of clients (default = {DEFAULT_N_CLIENTS})" + ) + + p_run.add_argument( + "--comms", + type = str, + default = None, + nargs = "*", + help = f"communication strategies to test (default = {[c.value for c in DEFAULT_COMMS]})" + ) + + p_run.add_argument( + "--configs", + type = str, + default = None, + nargs = "*", + help = f"configurations to test (default = {[c for c in CONFIGS]})" + ) + + p_run.set_defaults(_handler=lambda ns: perf_run( duration = ns.duration, num_buffers = ns.num_buffers, - iters = ns.iters + iters = ns.iters, + msg_sizes = ns.msg_sizes, + n_clients = ns.n_clients, + comms = ns.comms, + configs = ns.configs, )) \ No newline at end of file From 173c09609941b72b80aae478c674da25904f9035 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 22 Sep 2025 15:00:11 -0400 Subject: [PATCH 09/88] changing sample_rate calculation --- src/ezmsg/util/perf/impl.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ezmsg/util/perf/impl.py b/src/ezmsg/util/perf/impl.py index 6bb0ded7..8421a51b 100644 --- a/src/ezmsg/util/perf/impl.py +++ b/src/ezmsg/util/perf/impl.py @@ -264,10 +264,10 @@ def perform_test( process_components = process_components, ) - return calculate_metrics(sink) + return calculate_metrics(sink, duration) -def calculate_metrics(sink: LoadTestSink) -> Metrics: +def calculate_metrics(sink: LoadTestSink, duration: float) -> Metrics: # Log some useful summary statistics min_timestamp = min(timestamp for timestamp, _, _ in sink.STATE.received_data) @@ -286,7 +286,7 @@ def calculate_metrics(sink: LoadTestSink) -> Metrics: num_samples = len(sink.STATE.received_data) ez.logger.info(f"Samples received: {num_samples}") - sample_rate = num_samples / (max_timestamp - min_timestamp) + sample_rate = num_samples / duration ez.logger.info(f"Sample rate: {sample_rate} Hz") latency_mean = total_latency / num_samples ez.logger.info(f"Mean latency: {latency_mean} s") From d17a4fe17dfc229d79c6e4ded3473d4a4639dce7 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 22 Sep 2025 16:52:50 -0400 Subject: [PATCH 10/88] fix: n-clients = 0 is useful and works --- src/ezmsg/util/perf/run.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ezmsg/util/perf/run.py b/src/ezmsg/util/perf/run.py index 094389c2..3ff70e91 100644 --- a/src/ezmsg/util/perf/run.py +++ b/src/ezmsg/util/perf/run.py @@ -36,8 +36,8 @@ def perf_run( if n_clients is None: n_clients = DEFAULT_N_CLIENTS - if any(c <= 0 for c in n_clients): - ez.logger.error('All tests must have >0 clients') + if any(c < 0 for c in n_clients): + ez.logger.error('All tests must have >=0 clients') return if msg_sizes is None: From 01212d3608d4dbdcba1e1ea34c8d556e6af2230f Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 22 Sep 2025 19:10:46 -0400 Subject: [PATCH 11/88] better support for msg_size = 0 and n_clients = 0 --- src/ezmsg/util/perf/impl.py | 10 ++++++---- src/ezmsg/util/perf/run.py | 4 ++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/ezmsg/util/perf/impl.py b/src/ezmsg/util/perf/impl.py index 8421a51b..faec5d19 100644 --- a/src/ezmsg/util/perf/impl.py +++ b/src/ezmsg/util/perf/impl.py @@ -183,10 +183,12 @@ def relay(config: ConfigSettings) -> Configuration: connections: ez.NetworkDefinition = [] relays = [LoadTestRelay(config.settings) for _ in range(config.n_clients)] - connections.append((config.source.OUTPUT, relays[0].INPUT)) - for from_relay, to_relay in zip(relays[:-1], relays[1:]): - connections.append((from_relay.OUTPUT, to_relay.INPUT)) - connections.append((relays[-1].OUTPUT, config.sink.INPUT)) + if len(relays): + connections.append((config.source.OUTPUT, relays[0].INPUT)) + for from_relay, to_relay in zip(relays[:-1], relays[1:]): + connections.append((from_relay.OUTPUT, to_relay.INPUT)) + connections.append((relays[-1].OUTPUT, config.sink.INPUT)) + else: connections.append((config.source.OUTPUT, config.sink.INPUT)) return relays, connections diff --git a/src/ezmsg/util/perf/run.py b/src/ezmsg/util/perf/run.py index 3ff70e91..80bf80cb 100644 --- a/src/ezmsg/util/perf/run.py +++ b/src/ezmsg/util/perf/run.py @@ -42,8 +42,8 @@ def perf_run( if msg_sizes is None: msg_sizes = DEFAULT_MSG_SIZES - if any(s <= 0 for s in msg_sizes): - ez.logger.error('All msg_sizes must be >0 bytes') + if any(s < 0 for s in msg_sizes): + ez.logger.error('All msg_sizes must be >=0 bytes') try: communications = DEFAULT_COMMS if comms is None else [Communication(c) for c in comms] From 30234ae07827666c5f26645a33fd31179de584a1 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 22 Sep 2025 19:49:13 -0400 Subject: [PATCH 12/88] also force_tcp on tcp_spread oops --- src/ezmsg/util/perf/impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ezmsg/util/perf/impl.py b/src/ezmsg/util/perf/impl.py index faec5d19..114a8b7a 100644 --- a/src/ezmsg/util/perf/impl.py +++ b/src/ezmsg/util/perf/impl.py @@ -220,7 +220,7 @@ def perform_test( dynamic_size = int(msg_size), duration = duration, buffers = buffers, - force_tcp = (comms == Communication.TCP), + force_tcp = (comms in (Communication.TCP, Communication.TCP_SPREAD)), ) source = LoadTestSource(settings) From e8d1a3dbe1d9cd6bcb7853ca557bb684f04f1934 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 22 Sep 2025 20:41:09 -0400 Subject: [PATCH 13/88] viztracer is useful for profiling --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 961271b4..c0888f86 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ dev = [ {include-group = "lint"}, {include-group = "test"}, "pre-commit>=4.3.0", + "viztracer>=1.0.4", ] lint = [ "flake8>=7.3.0", From a7913656790da097fbbd119036923ed185df6087 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Tue, 23 Sep 2025 09:22:54 -0400 Subject: [PATCH 14/88] fix for rare exception on system shutdown --- src/ezmsg/core/subclient.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ezmsg/core/subclient.py b/src/ezmsg/core/subclient.py index 50e1aa0c..9a28d9ee 100644 --- a/src/ezmsg/core/subclient.py +++ b/src/ezmsg/core/subclient.py @@ -210,7 +210,7 @@ async def _handle_publisher( self._incoming.put_nowait((id, msg_id)) - except (ConnectionResetError, BrokenPipeError): + except (ConnectionResetError, BrokenPipeError, asyncio.IncompleteReadError): logger.debug(f"connection fail: sub:{self.id} -> pub:{id}") finally: From 77e38e9e48c2e9b289d424ff4cd78a46ec575e27 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Tue, 23 Sep 2025 13:09:54 -0400 Subject: [PATCH 15/88] added median latency --- src/ezmsg/util/perf/impl.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/ezmsg/util/perf/impl.py b/src/ezmsg/util/perf/impl.py index 114a8b7a..99d11046 100644 --- a/src/ezmsg/util/perf/impl.py +++ b/src/ezmsg/util/perf/impl.py @@ -41,6 +41,7 @@ class Metrics: num_msgs: int sample_rate: float latency_mean: float + latency_median: float latency_total: float data_rate: float @@ -274,12 +275,11 @@ def calculate_metrics(sink: LoadTestSink, duration: float) -> Metrics: # Log some useful summary statistics min_timestamp = min(timestamp for timestamp, _, _ in sink.STATE.received_data) max_timestamp = max(timestamp for timestamp, _, _ in sink.STATE.received_data) - total_latency = abs( - sum( - receive_timestamp - send_timestamp - for send_timestamp, receive_timestamp, _ in sink.STATE.received_data - ) - ) + latency = [ + receive_timestamp - send_timestamp + for send_timestamp, receive_timestamp, _ in sink.STATE.received_data + ] + total_latency = abs(sum(latency)) counters = list(sorted(t[2] for t in sink.STATE.received_data)) dropped_samples = sum( @@ -291,7 +291,9 @@ def calculate_metrics(sink: LoadTestSink, duration: float) -> Metrics: sample_rate = num_samples / duration ez.logger.info(f"Sample rate: {sample_rate} Hz") latency_mean = total_latency / num_samples + latency_median = list(sorted(latency))[len(latency) // 2] ez.logger.info(f"Mean latency: {latency_mean} s") + ez.logger.info(f"Median latency: {latency_median} s") ez.logger.info(f"Total latency: {total_latency} s") total_data = num_samples * sink.SETTINGS.dynamic_size @@ -307,6 +309,7 @@ def calculate_metrics(sink: LoadTestSink, duration: float) -> Metrics: num_msgs = num_samples, sample_rate = sample_rate, latency_mean = latency_mean, + latency_median = latency_median, latency_total = total_latency, data_rate = data_rate ) From 4693d5cfd5d04d52300468ad77297ab6a936c3c4 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Tue, 23 Sep 2025 15:30:46 -0400 Subject: [PATCH 16/88] added median latency to performance report --- src/ezmsg/util/perf/analysis.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/ezmsg/util/perf/analysis.py b/src/ezmsg/util/perf/analysis.py index 681c9dac..ed729187 100644 --- a/src/ezmsg/util/perf/analysis.py +++ b/src/ezmsg/util/perf/analysis.py @@ -246,11 +246,11 @@ def _color_for_comparison(value: float, metric: str, noise_band_pct: float = 5.0 delta = value - 100.0 # Determine direction: + is good for sample/data; - is good for latency - if metric in ("sample_rate", "data_rate"): + if 'rate' in metric: # positive delta good, negative bad magnitude = abs(delta) sign_good = delta > 0 - elif metric == "latency_mean": + elif 'latency' in metric: # negative delta good (lower latency) magnitude = abs(delta) sign_good = delta < 0 @@ -316,7 +316,7 @@ def summary(perf_path: Path, baseline_path: Path | None, html: bool = False) -> if html: # Ensure expected columns exist - expected_cols = {"sample_rate", "data_rate", "latency_mean"} + expected_cols = {"sample_rate", "data_rate", "latency_mean", "latency_median"} missing = expected_cols - set(df.columns) if missing: raise ValueError(f"Missing expected columns in dataset: {missing}") @@ -358,7 +358,7 @@ def summary(perf_path: Path, baseline_path: Path | None, html: bool = False) -> # Render each group for (config, comms), g in groups: # Keep only expected columns in order - cols = ["n_clients", "msg_size", "sample_rate", "data_rate", "latency_mean"] + cols = ["n_clients", "msg_size", "sample_rate", "data_rate", "latency_mean", "latency_median"] g = g[cols].copy() # String format some columns (msg_size with separators) @@ -374,17 +374,20 @@ def summary(perf_path: Path, baseline_path: Path | None, html: bool = False) -> sample_rate {'' if relative else '(msgs/s)'} data_rate {'' if relative else '(MB/s)'} latency_mean {'' if relative else '(us)'} + latency_median {'' if relative else '(us)'} """ body_rows: list[str] = [] for _, row in g.iterrows(): - sr, dr, lt = row["sample_rate"], row["data_rate"], row["latency_mean"] + sr, dr, lt, lm = row["sample_rate"], row["data_rate"], row["latency_mean"], row["latency_median"] dr = dr if relative else dr / 2**20 lt = lt if relative else lt * 1e6 + lm = lm if relative else lm * 1e6 sr_style = _color_for_comparison(sr, "sample_rate") if relative else "" dr_style = _color_for_comparison(dr, "data_rate") if relative else "" lt_style = _color_for_comparison(lt, "latency_mean") if relative else "" + lm_style = _color_for_comparison(lm, "latency_median") if relative else "" body_rows.append( "" @@ -393,6 +396,7 @@ def summary(perf_path: Path, baseline_path: Path | None, html: bool = False) -> f"{_format_number(sr)}" f"{_format_number(dr)}" f"{_format_number(lt)}" + f"{_format_number(lm)}" "" ) table_html = f"{header}{''.join(body_rows)}
" From e8f49947cb71eddb1b0f8970bca3950770ad0563 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Tue, 23 Sep 2025 16:07:10 -0400 Subject: [PATCH 17/88] try using time.time --- src/ezmsg/util/perf/impl.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ezmsg/util/perf/impl.py b/src/ezmsg/util/perf/impl.py index 99d11046..9fdfe740 100644 --- a/src/ezmsg/util/perf/impl.py +++ b/src/ezmsg/util/perf/impl.py @@ -75,16 +75,16 @@ async def initialize(self) -> None: @ez.publisher(OUTPUT) async def publish(self) -> typing.AsyncGenerator: ez.logger.info(f"Load test publisher started. (PID: {os.getpid()})") - start_time = time.perf_counter() + start_time = time.time() while self.running: - current_time = time.perf_counter() + current_time = time.time() if current_time - start_time >= self.SETTINGS.duration: break yield ( self.OUTPUT, LoadTestSample( - _timestamp=time.perf_counter(), + _timestamp=time.time(), counter=self.counter, dynamic_data=np.zeros( int(self.SETTINGS.dynamic_size // 8), dtype=np.float32 @@ -133,7 +133,7 @@ async def receive(self, sample: LoadTestSample) -> None: f"{sample.counter - counter - 1} samples skipped!" ) self.STATE.received_data.append( - (sample._timestamp, time.perf_counter(), sample.counter) + (sample._timestamp, time.time(), sample.counter) ) self.STATE.counters[sample.key] = sample.counter From 59d2c5997133fddfd4050a6f6f51657725cfa4aa Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Wed, 13 Aug 2025 11:42:06 -0400 Subject: [PATCH 18/88] fix: graphserver port no-longer lingers --- src/ezmsg/core/netprotocol.py | 32 +++++++++++++++++++------------- src/ezmsg/core/server.py | 8 ++++---- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/src/ezmsg/core/netprotocol.py b/src/ezmsg/core/netprotocol.py index 4024324b..f08597b3 100644 --- a/src/ezmsg/core/netprotocol.py +++ b/src/ezmsg/core/netprotocol.py @@ -183,16 +183,22 @@ def create_socket( if port is not None: sock.bind((host, port)) - return sock - - port = start_port - while port <= max_port: - if port not in ignore_ports: - try: - sock.bind((host, port)) - return sock - except OSError: - pass - port += 1 - - raise IOError("Failed to bind socket; no free ports") + + else: + bound = False + port = start_port + while port <= max_port: + if port not in ignore_ports: + try: + sock.bind((host, port)) + bound = True + break + except OSError: + pass + port += 1 + + if not bound: + raise IOError("Failed to bind socket; no free ports") + + sock.setblocking(False) + return sock diff --git a/src/ezmsg/core/server.py b/src/ezmsg/core/server.py index f925a8f2..42020483 100644 --- a/src/ezmsg/core/server.py +++ b/src/ezmsg/core/server.py @@ -71,7 +71,8 @@ async def _serve(self) -> None: async def monitor_shutdown() -> None: await self._loop.run_in_executor(None, self._shutdown.wait) - await close_server(server) + server.close() + await server.wait_closed() monitor_task = self._loop.create_task(monitor_shutdown()) @@ -82,9 +83,8 @@ async def monitor_shutdown() -> None: finally: await self.shutdown() - monitor_task.cancel() - with suppress(asyncio.CancelledError): - await monitor_task + self._shutdown.set() + await monitor_task async def setup(self) -> None: ... From 79b4339cd1372cb4d13e1b3130129b6a6b914231 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Wed, 13 Aug 2025 11:44:13 -0400 Subject: [PATCH 19/88] comments to self --- src/ezmsg/core/graphserver.py | 10 ++++++-- src/ezmsg/core/shm.py | 43 ++++++++++++++++++++++------------- 2 files changed, 35 insertions(+), 18 deletions(-) diff --git a/src/ezmsg/core/graphserver.py b/src/ezmsg/core/graphserver.py index fad7515b..896443ed 100644 --- a/src/ezmsg/core/graphserver.py +++ b/src/ezmsg/core/graphserver.py @@ -98,10 +98,12 @@ async def api( info: Optional[SHMInfo] = None if req == Command.SHM_CREATE.value: + # TODO: UUID num_buffers = await read_int(reader) buf_size = await read_int(reader) - # Create segment + # Create segment + # TODO: Move me into shm.py shm = SharedMemory(size=num_buffers * buf_size, create=True) shm.buf[:] = b"0" * len(shm.buf) # Guarantee zeros shm.buf[0:8] = uint64_to_bytes(num_buffers) @@ -393,10 +395,14 @@ async def get_formatted_graph( return formatted_graph async def create_shm( - self, num_buffers: int, buf_size: int = DEFAULT_SHM_SIZE + self, + # TODO: add UUID parameter + num_buffers: int, + buf_size: int = DEFAULT_SHM_SIZE ) -> SHMContext: reader, writer = await self.open_connection() writer.write(Command.SHM_CREATE.value) + # TODO: serialize UUID writer.write(uint64_to_bytes(num_buffers)) writer.write(uint64_to_bytes(buf_size)) await writer.drain() diff --git a/src/ezmsg/core/shm.py b/src/ezmsg/core/shm.py index adcab4f9..41a3fdb2 100644 --- a/src/ezmsg/core/shm.py +++ b/src/ezmsg/core/shm.py @@ -2,6 +2,7 @@ import logging import typing +from uuid import UUID from dataclasses import dataclass, field from contextlib import contextmanager, suppress from multiprocessing import resource_tracker @@ -16,13 +17,31 @@ _std_register = resource_tracker.register +""" + ezmsg shared memory format: + TODO: [UUID] + [ UINT64 -- n_buffers ] + [ UINT64 -- buf_size ] + [ buf_size - 16 -- buf0 data_block ] + ... + [ 0x00 * 16 bytes -- reserved (header) ] + [ buf_size - 16 -- buf1 data_block ] + ... + * n_buffers defines the number of shared memory buffers + * buf_size defines the size of each shared memory buffer (including 16 bytes for header) + * data_block is the remaining memory in this buffer which contains message information + + This format repeats itself for every buffer in the SharedMemory block. +""" + + +# TODO: Update with a more recent monkeypatch def _ignore_shm(name, rtype): if rtype == "shared_memory": return return resource_tracker._resource_tracker.register(self, name, rtype) # noqa: F821 - @contextmanager def _untracked_shm() -> typing.Generator[None, None, None]: """Disable SHM tracking within context - https://bugs.python.org/issue38119""" @@ -35,21 +54,6 @@ class SHMContext: """ SHMContext manages the memory map of a block of shared memory, and exposes memoryview objects for reading and writing - - ezmsg shared memory format: - [ UINT64 -- n_buffers ] - [ UINT64 -- buf_size ] - [ buf_size - 16 -- buf0 data_block ] - ... - [ 0x00 * 16 bytes -- reserved (header) ] - [ buf_size - 16 -- buf1 data_block ] - ... - - * n_buffers defines the number of shared memory buffers - * buf_size defines the size of each shared memory buffer (including 16 bytes for header) - * data_block is the remaining memory in this buffer which contains message information - - This format repeats itself for every buffer in the SharedMemory block. """ _shm: SharedMemory @@ -78,6 +82,7 @@ def __init__(self, name: str) -> None: slice(*seg) for seg in zip(buf_data_block_starts, buf_stops) ] + # TODO: Get rid of underscore and rename to attach @classmethod def _create( cls, shm_name: str, reader: asyncio.StreamReader, writer: asyncio.StreamWriter @@ -147,6 +152,12 @@ class SHMInfo: shm: SharedMemory leases: typing.Set["asyncio.Task[None]"] = field(default_factory=set) + @classmethod + def create( + cls, uuid: UUID, num_buffers: int, buf_size: int + ) -> "SHMInfo": + ... + def lease( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter ) -> "asyncio.Task[None]": From 77eca8be62beb6be53e82db0f239db2c8f7f055d Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Fri, 18 Jul 2025 08:17:36 -0400 Subject: [PATCH 20/88] write optimization for tcp and SHM --- src/ezmsg/core/pubclient.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index 1b6b111f..f1880726 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -268,7 +268,6 @@ async def broadcast(self, obj: Any) -> None: if not self._force_tcp and sub.shm_access: if sub.pid == self.pid: sub.writer.write(Command.TX_LOCAL.value + msg_id_bytes) - else: try: # Push cache to shm (if not already there) @@ -288,21 +287,20 @@ async def broadcast(self, obj: Any) -> None: self._shm = new_shm MessageCache[self.id].push(self._msg_id, self._shm) - sub.writer.write(Command.TX_SHM.value) - sub.writer.write(msg_id_bytes) - sub.writer.write(encode_str(self._shm.name)) + sub.writer.write(Command.TX_SHM.value + msg_id_bytes + encode_str(self._shm.name)) else: with MessageMarshal.serialize(self._msg_id, obj) as ser_obj: total_size, header, buffers = ser_obj total_size_bytes = uint64_to_bytes(total_size) - sub.writer.write(Command.TX_TCP.value) - sub.writer.write(msg_id_bytes) - sub.writer.write(total_size_bytes) - sub.writer.write(header) - for buffer in buffers: - sub.writer.write(buffer) + sub.writer.write( + Command.TX_TCP.value + \ + msg_id_bytes + \ + total_size_bytes + \ + header + \ + b''.join([buffer for buffer in buffers]) + ) try: await sub.writer.drain() From 48ab00dc2e67e2a1d7768aff31151564fd209bd8 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Fri, 18 Jul 2025 08:18:36 -0400 Subject: [PATCH 21/88] formatting --- src/ezmsg/core/pubclient.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index f1880726..7c4fff1a 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -287,7 +287,11 @@ async def broadcast(self, obj: Any) -> None: self._shm = new_shm MessageCache[self.id].push(self._msg_id, self._shm) - sub.writer.write(Command.TX_SHM.value + msg_id_bytes + encode_str(self._shm.name)) + sub.writer.write( + Command.TX_SHM.value + \ + msg_id_bytes + \ + encode_str(self._shm.name) + ) else: with MessageMarshal.serialize(self._msg_id, obj) as ser_obj: From 776137044b78e7156553517887bf1ca736f89641 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Fri, 8 Aug 2025 08:32:11 -0400 Subject: [PATCH 22/88] tcp write optimization --- src/ezmsg/core/pubclient.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index 7c4fff1a..f39404cc 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -75,8 +75,7 @@ async def create( writer.write(Command.PUBLISH.value) id = UUID(await read_str(reader)) pub = cls(id, topic, graph_service, **kwargs) - writer.write(uint64_to_bytes(pub.pid)) - writer.write(encode_str(pub.topic)) + writer.write(uint64_to_bytes(pub.pid) + encode_str(pub.topic)) pub._shm = await graph_service.create_shm(pub._num_buffers, buf_size) start_port = int( From 2e09d62cc6854f68e51a26cef7c72f12ec72bccb Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Wed, 20 Aug 2025 12:42:31 -0400 Subject: [PATCH 23/88] backend rework: channels --- src/ezmsg/core/backendprocess.py | 6 +- src/ezmsg/core/graphserver.py | 155 +++++++++------- src/ezmsg/core/messagecache.py | 77 -------- src/ezmsg/core/messagechannel.py | 293 +++++++++++++++++++++++++++++++ src/ezmsg/core/messagemarshal.py | 19 +- src/ezmsg/core/netprotocol.py | 20 ++- src/ezmsg/core/pubclient.py | 220 ++++++++++++----------- src/ezmsg/core/server.py | 4 +- src/ezmsg/core/shm.py | 19 +- src/ezmsg/core/subclient.py | 200 +++++---------------- 10 files changed, 592 insertions(+), 421 deletions(-) delete mode 100644 src/ezmsg/core/messagecache.py create mode 100644 src/ezmsg/core/messagechannel.py diff --git a/src/ezmsg/core/backendprocess.py b/src/ezmsg/core/backendprocess.py index 40668520..98a46f75 100644 --- a/src/ezmsg/core/backendprocess.py +++ b/src/ezmsg/core/backendprocess.py @@ -22,7 +22,7 @@ from .graphserver import GraphService from .pubclient import Publisher from .subclient import Subscriber -from .messagecache import MessageCache +from .messagechannel import CHANNELS from abc import abstractmethod from typing import ( @@ -258,8 +258,8 @@ async def shutdown_units() -> None: asyncio.run_coroutine_threadsafe(shutdown_units(), loop=loop).result() - for cache in MessageCache.values(): - cache.clear() + # for cache in MessageCache.values(): + # cache.clear() asyncio.run_coroutine_threadsafe(context.revert(), loop=loop).result() diff --git a/src/ezmsg/core/graphserver.py b/src/ezmsg/core/graphserver.py index 896443ed..125d8569 100644 --- a/src/ezmsg/core/graphserver.py +++ b/src/ezmsg/core/graphserver.py @@ -3,7 +3,7 @@ import pickle from contextlib import suppress from typing import Dict, List, Optional, Tuple -from uuid import UUID, getnode, uuid1 +from uuid import UUID, uuid1 from . import __version__ from .dag import DAG, CyclicException @@ -14,6 +14,7 @@ ClientInfo, SubscriberInfo, PublisherInfo, + ChannelInfo, AddressType, close_stream_writer, encode_str, @@ -25,7 +26,7 @@ DEFAULT_SHM_SIZE, ) from .server import ServiceManager, ThreadedAsyncServer -from .shm import SharedMemory, SHMContext, SHMInfo +from .shm import SHMContext, SHMInfo logger = logging.getLogger("ezmsg") @@ -34,13 +35,10 @@ class GraphServer(ThreadedAsyncServer): """ Pub-Sub Directed Acyclic Graph - Running as a process: start() and stop() - Running as a thread: start_server(), stop_server(), join_server() """ graph: DAG clients: Dict[UUID, ClientInfo] - node: int shms: Dict[str, SHMInfo] _client_tasks: Dict[UUID, "asyncio.Task[None]"] @@ -52,7 +50,6 @@ def __init__(self) -> None: self.clients = dict() self._client_tasks = dict() - self.node = getnode() self.shms = dict() async def setup(self) -> None: @@ -76,7 +73,6 @@ async def api( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter ) -> None: try: - node = await read_int(reader) writer.write(encode_str(__version__)) await writer.drain() @@ -95,7 +91,7 @@ async def api( return elif req in [Command.SHM_CREATE.value, Command.SHM_ATTACH.value]: - info: Optional[SHMInfo] = None + shm_info: Optional[SHMInfo] = None if req == Command.SHM_CREATE.value: # TODO: UUID @@ -103,60 +99,97 @@ async def api( buf_size = await read_int(reader) # Create segment - # TODO: Move me into shm.py - shm = SharedMemory(size=num_buffers * buf_size, create=True) - shm.buf[:] = b"0" * len(shm.buf) # Guarantee zeros - shm.buf[0:8] = uint64_to_bytes(num_buffers) - shm.buf[8:16] = uint64_to_bytes(buf_size) - info = SHMInfo(shm) - self.shms[shm.name] = info - logger.debug(f"created {shm.name}") + shm_info = SHMInfo.create(num_buffers, buf_size) + self.shms[shm_info.shm.name] = shm_info + logger.debug(f"created {shm_info.shm.name}") elif req == Command.SHM_ATTACH.value: shm_name = await read_str(reader) - info = self.shms.get(shm_name, None) + shm_info = self.shms.get(shm_name, None) - if info is None: + if shm_info is None: await close_stream_writer(writer) return writer.write(Command.COMPLETE.value) - writer.write(encode_str(info.shm.name)) + writer.write(encode_str(shm_info.shm.name)) with suppress(asyncio.CancelledError): - await info.lease(reader, writer) + await shm_info.lease(reader, writer) else: # We only want to handle one command at a time async with self._command_lock: - if req in [Command.SUBSCRIBE.value, Command.PUBLISH.value]: - id = uuid1(node=node) - writer.write(encode_str(str(id))) - - pid = await read_int(reader) - topic = await read_str(reader) - - if req == Command.SUBSCRIBE.value: - info = SubscriberInfo(id, writer, pid, topic) - self.clients[id] = info - self._client_tasks[id] = asyncio.create_task( - self._handle_client(id, reader, writer) - ) - iface = writer.transport.get_extra_info("sockname")[0] - await self._notify_subscriber(info, iface) - - elif req == Command.PUBLISH.value: - address = await Address.from_stream(reader) - info = PublisherInfo(id, writer, pid, topic, address) - self.clients[id] = info - self._client_tasks[id] = asyncio.create_task( - self._handle_client(id, reader, writer) - ) - iface = writer.transport.get_extra_info("peername")[0] - for sub in self._downstream_subs(info.topic): - await self._notify_subscriber(sub, iface) + if req in [ + Command.SUBSCRIBE.value, + Command.PUBLISH.value, + Command.CHANNEL.value + ]: + id = uuid1() + + if req in [Command.SUBSCRIBE.value, Command.PUBLISH.value]: + topic = await read_str(reader) + + if req == Command.SUBSCRIBE.value: + info = SubscriberInfo(id, writer, topic) + self.clients[id] = info + self._client_tasks[id] = asyncio.create_task( + self._handle_client(id, reader, writer) + ) + + writer.write(Command.COMPLETE.value) + writer.write(encode_str(str(id))) + + await self._notify_subscriber(info) + + elif req == Command.PUBLISH.value: + address = await Address.from_stream(reader) + info = PublisherInfo(id, writer, topic, address) + self.clients[id] = info + self._client_tasks[id] = asyncio.create_task( + self._handle_client(id, reader, writer) + ) + + writer.write(Command.COMPLETE.value) + writer.write(encode_str(str(id))) + + for sub in self._downstream_subs(info.topic): + await self._notify_subscriber(sub) + + + elif req == Command.CHANNEL.value: + pub_id_str = await read_str(reader) + pub_id = UUID(pub_id_str) + + pub_addr = None + try: + pub_info = self.clients[pub_id] + if isinstance(pub_info, PublisherInfo): + # assemble an address the channel should be able to resolve + port = pub_info.address.port + iface = pub_info.writer.transport.get_extra_info("peername")[0] + pub_addr = Address(iface, port) + else: + logger.warning(f"Connecting channel requested {type(pub_info)}") + + except KeyError: + logger.warning(f"Connecting channel requested non-existent publisher {pub_id=}") + + if pub_addr is not None: + writer.write(Command.COMPLETE.value) + writer.write(encode_str(str(id))) + pub_addr.to_stream(writer) + + info = ChannelInfo(id, writer, pub_id) + self.clients[id] = info + self._client_tasks[id] = asyncio.create_task( + self._handle_client(id, reader, writer) + ) + else: + # Error, drop connection + await close_stream_writer(writer) + return - writer.write(Command.COMPLETE.value) await writer.drain() return @@ -237,23 +270,12 @@ async def _handle_client( del self.clients[id] await close_stream_writer(writer) - async def _notify_subscriber( - self, sub: SubscriberInfo, iface: Optional[str] = None - ) -> None: + async def _notify_subscriber(self, sub: SubscriberInfo) -> None: try: - notification = [] - for pub in self._upstream_pubs(sub.topic): - address = pub.address - if address is not None: - if iface is not None: - notification.append(f"{str(pub.id)}@{iface}:{address.port}") - else: - notification.append( - f"{str(pub.id)}@{address.host}:{address.port}" - ) + pub_ids = [str(pub.id) for pub in self._upstream_pubs(sub.topic)] async with sub.sync_writer() as writer: - notify_str = ",".join(notification) + notify_str = ",".join(pub_ids) writer.write(Command.UPDATE.value) writer.write(encode_str(notify_str)) @@ -269,6 +291,11 @@ def _subscribers(self) -> List[SubscriberInfo]: return [ info for info in self.clients.values() if isinstance(info, SubscriberInfo) ] + + def _channels(self) -> List[ChannelInfo]: + return [ + info for info in self.clients.values() if isinstance(info, ChannelInfo) + ] def _upstream_pubs(self, topic: str) -> List[PublisherInfo]: """Given a topic, return a set of all publisher IDs upstream of that topic""" @@ -292,8 +319,6 @@ async def open_connection( self, ) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]: reader, writer = await super().open_connection() - writer.write(uint64_to_bytes(getnode())) - await writer.drain() server_version = await read_str(reader) if server_version != __version__: logger.warning( @@ -412,7 +437,7 @@ async def create_shm( raise ValueError("Error creating SHM segment") shm_name = await read_str(reader) - return SHMContext._create(shm_name, reader, writer) + return SHMContext.attach(shm_name, reader, writer) async def attach_shm(self, name: str) -> SHMContext: reader, writer = await self.open_connection() @@ -425,4 +450,4 @@ async def attach_shm(self, name: str) -> SHMContext: raise ValueError("Invalid SHM Name") shm_name = await read_str(reader) - return SHMContext._create(shm_name, reader, writer) + return SHMContext.attach(shm_name, reader, writer) \ No newline at end of file diff --git a/src/ezmsg/core/messagecache.py b/src/ezmsg/core/messagecache.py deleted file mode 100644 index 3eb9c285..00000000 --- a/src/ezmsg/core/messagecache.py +++ /dev/null @@ -1,77 +0,0 @@ -import logging - -from uuid import UUID -from contextlib import contextmanager - -from .shm import SHMContext -from .messagemarshal import MessageMarshal, UninitializedMemory - -from typing import Dict, Any, Optional, List, Generator - -logger = logging.getLogger("ezmsg") - - -class CacheMiss(Exception): ... - - -class Cache: - """shared-memory backed cache for objects""" - - num_buffers: int - cache: List[Any] - cache_id: List[Optional[int]] - - def __init__(self, num_buffers: int) -> None: - self.num_buffers = num_buffers - self.cache_id = [None] * self.num_buffers - self.cache = [None] * self.num_buffers - - def put(self, msg_id: int, msg: Any) -> None: - """put an object into cache""" - buf_idx = msg_id % self.num_buffers - self.cache_id[buf_idx] = msg_id - self.cache[buf_idx] = msg - - def push(self, msg_id: int, shm: SHMContext) -> None: - """push an object from cache into shm""" - if self.num_buffers != shm.num_buffers: - raise ValueError("shm has incorrect number of buffers") - - buf_idx = msg_id % self.num_buffers - with shm.buffer(buf_idx) as mem: - shm_msg_id = MessageMarshal.msg_id(mem) - if shm_msg_id != msg_id: - with self.get(msg_id) as obj: - MessageMarshal.to_mem(msg_id, obj, mem) - - @contextmanager - def get( - self, msg_id: int, shm: Optional[SHMContext] = None - ) -> Generator[Any, None, None]: - """get object from cache; if not in cache and shm provided -- get from shm""" - - buf_idx = msg_id % self.num_buffers - if self.cache_id[buf_idx] == msg_id: - yield self.cache[buf_idx] - - else: - if shm is None: - raise CacheMiss - - with shm.buffer(buf_idx, readonly=True) as mem: - try: - if MessageMarshal.msg_id(mem) != msg_id: - raise CacheMiss - except UninitializedMemory: - raise CacheMiss - - with MessageMarshal.obj_from_mem(mem) as obj: - yield obj - - def clear(self): - self.cache_id = [None] * self.num_buffers - self.cache = [None] * self.num_buffers - - -# FIXME: This should be made thread-safe in the future -MessageCache: Dict[UUID, Cache] = dict() diff --git a/src/ezmsg/core/messagechannel.py b/src/ezmsg/core/messagechannel.py new file mode 100644 index 00000000..a60ec5d0 --- /dev/null +++ b/src/ezmsg/core/messagechannel.py @@ -0,0 +1,293 @@ +import os +import asyncio +import typing +import logging + +from uuid import UUID +from contextlib import contextmanager + +from .shm import SHMContext +from .messagemarshal import MessageMarshal +from .backpressure import Backpressure + +from .graphserver import GraphService +from .netprotocol import ( + Command, + Address, + AddressType, + read_str, + read_int, + uint64_to_bytes, + encode_str, + close_stream_writer, + GRAPHSERVER_ADDR +) + +logger = logging.getLogger("ezmsg") + + +class CacheMiss(Exception): ... + + +class _Channel: + """cache-backed message channel for a particular publisher""" + + id: UUID + pub_id: UUID + pid: int + topic: str + + num_buffers: int + cache: typing.List[typing.Any] + cache_id: typing.List[int | None] + shm: SHMContext | None + sub_queues: typing.Dict[UUID, asyncio.Queue[typing.Tuple[UUID, int]]] + backpressure: Backpressure + + _graph_task: asyncio.Task[None] + _pub_task: asyncio.Task[None] + _pub_writer: asyncio.StreamWriter + _graph_address: AddressType | None + + def __init__( + self, + id: UUID, + pub_id: UUID, + num_buffers: int, + shm: SHMContext | None, + graph_address: AddressType | None = None + ) -> None: + self.id = id + self.pub_id = pub_id + self.num_buffers = num_buffers + self.shm = shm + + self.cache_id = [None] * self.num_buffers + self.cache = [None] * self.num_buffers + self.backpressure = Backpressure(self.num_buffers) + self.sub_queues = dict() + self._graph_address = graph_address + + @classmethod + async def create( + cls, + pub_id: UUID, + graph_address: AddressType, + ) -> "_Channel": + graph_service = GraphService(graph_address) + + graph_reader, graph_writer = await graph_service.open_connection() + graph_writer.write(Command.CHANNEL.value) + graph_writer.write(encode_str(str(pub_id))) + + response = await graph_reader.read(1) + if response != Command.COMPLETE.value: + raise ValueError(f'failed to create channel {pub_id=}') + + id_str = await read_str(graph_reader) + pub_address = await Address.from_stream(graph_reader) + + reader, writer = await asyncio.open_connection(*pub_address) + + writer.write(Command.CHANNEL.value) + writer.write(encode_str(id_str)) + + shm = None + shm_name = await read_str(reader) + try: + shm = await graph_service.attach_shm(shm_name) + writer.write(Command.SHM_OK.value) + except (ValueError, OSError): + writer.write(Command.SHM_ATTACH_FAILED.value) + writer.write(uint64_to_bytes(os.getpid())) + + result = await reader.readexactly(1) + if result != Command.COMPLETE.value: + raise ValueError(f'failed to create channel {pub_id=}') + + num_buffers = await read_int(reader) + + chan = cls(UUID(id_str), pub_id, num_buffers, shm) + + chan._graph_task = asyncio.create_task( + chan._graph_connection(graph_reader, graph_writer) + ) + + chan._pub_writer = writer + chan._pub_task = asyncio.create_task( + chan._publisher_connection(reader) + ) + + return chan + + async def _graph_connection( + self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter + ) -> None: + try: + while True: + cmd = await reader.read(1) + + if not cmd: + break + + else: + logger.warning( + f"Channel {self.id} rx unknown command from GraphServer: {cmd}" + ) + except (ConnectionResetError, BrokenPipeError): + logger.debug(f"Channel {self.id} lost connection to graph server") + + finally: + await close_stream_writer(writer) + + async def _publisher_connection(self, reader: asyncio.StreamReader) -> None: + try: + while True: + msg = await reader.read(1) + + if not msg: + break + + msg_id = await read_int(reader) + buf_idx = msg_id % self.num_buffers + + if msg == Command.TX_SHM.value: + shm_name = await read_str(reader) + + if self.shm is not None and self.shm.name != shm_name: + self.shm.close() + try: + self.shm = await GraphService(self._graph_address).attach_shm(shm_name) + except ValueError: + logger.info( + "Invalid SHM received from publisher; may be dead" + ) + raise + + elif msg == Command.TX_TCP.value: + buf_size = await read_int(reader) + obj_bytes = await reader.readexactly(buf_size) + + with MessageMarshal.obj_from_mem(memoryview(obj_bytes)) as obj: + self.cache[buf_idx] = obj + self.cache_id[buf_idx] = msg_id + + self._notify_subs(msg_id) + + except (ConnectionResetError, BrokenPipeError): + logger.debug(f"connection fail: channel:{self.id} - pub:{self.pub_id}") + + finally: + await close_stream_writer(self._pub_writer) + # TODO: Remove this channel from CHANNELS ... ? maybe? + logger.debug(f"disconnected: channel:{self.id} -> pub:{id}") + + def _notify_subs(self, msg_id: int) -> None: + for sub_id, queue in self.sub_queues.items(): + self.backpressure.lease(sub_id, msg_id % self.num_buffers) + queue.put_nowait((self.pub_id, msg_id)) + + def put(self, msg_id: int, msg: typing.Any) -> None: + """put an object into cache (should only be used by Publishers)""" + buf_idx = msg_id % self.num_buffers + self.cache_id[buf_idx] = msg_id + self.cache[buf_idx] = msg + self._notify_subs(msg_id) + + @contextmanager + def get(self, msg_id: int, sub_id: UUID) -> typing.Generator[typing.Any, None, None]: + """get object from cache; if not in cache and shm provided -- get from shm""" + + buf_idx = msg_id % self.num_buffers + if self.cache_id[buf_idx] == msg_id: + yield self.cache[buf_idx] + + else: + if self.shm is None: + raise CacheMiss + + with self.shm.buffer(buf_idx, readonly=True) as mem: + if MessageMarshal.msg_id(mem) != msg_id: + raise CacheMiss + + with MessageMarshal.obj_from_mem(mem) as obj: + # Could deepcopy and put in cache here, but + # profiling indicates its faster to repeatedly + # reconstruct from memory for fanout <= 4 subs + # which I suspect will be majority of cases + yield obj + + self.backpressure.free(sub_id, buf_idx) + if self.backpressure.buffers[buf_idx].is_empty: + try: + ack = Command.RX_ACK.value + uint64_to_bytes(msg_id) + self._pub_writer.write(ack) + except (BrokenPipeError, ConnectionResetError): + logger.debug(f"ack fail: channel:{self.id} -> pub:{self.pub_id}") + + def subscribe(self, sub_id: UUID, sub_queue: asyncio.Queue[typing.Tuple[UUID, int]]) -> None: + self.sub_queues[sub_id] = sub_queue + + def unsubscribe(self, sub_id: UUID) -> None: + queue = self.sub_queues.get(sub_id, None) + + if queue is None: + return + + del self.sub_queues[sub_id] + + for _ in range(queue.qsize()): + ch_id, msg_id = queue.get_nowait() + if ch_id == self.id: + continue + queue.put_nowait((ch_id, msg_id)) + + self.backpressure.free(sub_id) + + def clear_cache(self): + self.cache_id = [None] * self.num_buffers + self.cache = [None] * self.num_buffers + + +class _ChannelManager: + + _registry: typing.Dict[Address, typing.Dict[UUID, _Channel]] + + def __init__(self): + default_address = Address.from_string(GRAPHSERVER_ADDR) + self._registry = {default_address: dict()} + + async def get(self, id: UUID, graph_address: AddressType | None = None) -> _Channel: + + if graph_address is None: + graph_address = Address.from_string(GRAPHSERVER_ADDR) + + elif not isinstance(graph_address, Address): + graph_address = Address(*graph_address) + + channels = self._registry.get(graph_address, dict()) + channel = channels.get(id, None) + if channel is None: + channel = await _Channel.create(id, graph_address) + channels[id] = channel + self._registry[graph_address] = channels + + return channel + + def unsubscribe_all(self, sub_id: UUID, graph_address: AddressType | None = None) -> None: + + if graph_address is None: + graph_address = Address.from_string(GRAPHSERVER_ADDR) + + elif not isinstance(graph_address, Address): + graph_address = Address(*graph_address) + + channels = self._registry.get(graph_address, None) + if channels is None: + return + + for channel in channels.values(): + channel.unsubscribe(sub_id) + + +CHANNELS = _ChannelManager() \ No newline at end of file diff --git a/src/ezmsg/core/messagemarshal.py b/src/ezmsg/core/messagemarshal.py index 537dc9e0..2266044c 100644 --- a/src/ezmsg/core/messagemarshal.py +++ b/src/ezmsg/core/messagemarshal.py @@ -38,13 +38,17 @@ def to_mem(cls, msg_id: int, obj: Any, mem: memoryview) -> None: if total_size >= len(mem): raise UndersizedMemory(req_size=total_size) + + cls._write(mem, header, buffers) - sidx = len(header) - mem[:sidx] = header[:] - for buf in buffers: - blen = len(buf) - mem[sidx : sidx + blen] = buf[:] - sidx += blen + @classmethod + def _write(cls, mem: memoryview, header: bytes, buffers: List[memoryview]): + sidx = len(header) + mem[:sidx] = header[:] + for buf in buffers: + blen = len(buf) + mem[sidx : sidx + blen] = buf[:] + sidx += blen @classmethod def _assert_initialized(cls, mem: memoryview) -> None: @@ -126,4 +130,7 @@ def copy_obj(cls, from_mem: memoryview, to_mem: memoryview) -> None: MessageMarshal.to_mem(msg_id, obj, to_mem) +# If some other byte-level representation is desired, you can just +# monkeypatch the module at runtime with a different Marhsal subclass +# TODO: This could also be done with environment variables MessageMarshal = Marshal diff --git a/src/ezmsg/core/netprotocol.py b/src/ezmsg/core/netprotocol.py index f08597b3..4ee9b531 100644 --- a/src/ezmsg/core/netprotocol.py +++ b/src/ezmsg/core/netprotocol.py @@ -2,6 +2,7 @@ import socket import typing import enum +import os from uuid import UUID from dataclasses import field, dataclass @@ -26,6 +27,11 @@ PUBLISHER_START_PORT_ENV = "EZMSG_PUBLISHER_PORT_START" PUBLISHER_START_PORT_DEFAULT = 25980 +GRAPHSERVER_ADDR = os.environ.get( + GRAPHSERVER_ADDR_ENV, + f"{DEFAULT_HOST}:{GRAPHSERVER_PORT_DEFAULT}" +) + class Address(typing.NamedTuple): host: str @@ -58,8 +64,6 @@ def __str__(self): class ClientInfo: id: UUID writer: asyncio.StreamWriter - pid: int - topic: str _pending: asyncio.Event = field(default_factory=asyncio.Event, init=False) @@ -83,12 +87,18 @@ async def sync_writer(self) -> typing.AsyncGenerator[asyncio.StreamWriter, None] @dataclass class PublisherInfo(ClientInfo): + topic: str address: Address @dataclass class SubscriberInfo(ClientInfo): - shm_access: bool = False + topic: str + + +@dataclass +class ChannelInfo(ClientInfo): + pub_id: UUID def uint64_to_bytes(i: int) -> bytes: @@ -166,6 +176,10 @@ def _generate_next_value_(name, start, count, last_values) -> bytes: SHM_ATTACH = enum.auto() SHUTDOWN = enum.auto() + + CHANNEL = enum.auto() + SHM_OK = enum.auto() + SHM_ATTACH_FAILED = enum.auto() def create_socket( diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index f39404cc..a1613aee 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -5,15 +5,17 @@ from uuid import UUID from contextlib import suppress +from dataclasses import dataclass from .backpressure import Backpressure from .shm import SHMContext from .graphserver import GraphService -from .messagecache import MessageCache, Cache +from .messagechannel import CHANNELS from .messagemarshal import MessageMarshal, UndersizedMemory from .netprotocol import ( Address, + AddressType, uint64_to_bytes, read_int, read_str, @@ -21,7 +23,7 @@ close_stream_writer, close_server, Command, - SubscriberInfo, + ChannelInfo, create_socket, DEFAULT_SHM_SIZE, PUBLISHER_START_PORT_ENV, @@ -36,6 +38,13 @@ BACKPRESSURE_REFRACTORY = 5.0 # sec +# Publisher needs a bit more information about connected channels +@dataclass +class PubChannelInfo(ChannelInfo): + pid: int + shm_ok: bool = False + + class Publisher: id: UUID pid: int @@ -44,8 +53,8 @@ class Publisher: _initialized: asyncio.Event _graph_task: "asyncio.Task[None]" _connection_task: "asyncio.Task[None]" - _subscribers: Dict[UUID, SubscriberInfo] - _subscriber_tasks: Dict[UUID, "asyncio.Task[None]"] + _channels: Dict[UUID, PubChannelInfo] + _channel_tasks: Dict[UUID, "asyncio.Task[None]"] _address: Address _backpressure: Backpressure _num_buffers: int @@ -55,7 +64,7 @@ class Publisher: _force_tcp: bool _last_backpressure_event: float - _graph_service: GraphService + _graph_address: AddressType | None @staticmethod def client_type() -> bytes: @@ -65,78 +74,94 @@ def client_type() -> bytes: async def create( cls, topic: str, - graph_service: GraphService, + graph_address: AddressType | None = None, host: Optional[str] = None, port: Optional[int] = None, buf_size: int = DEFAULT_SHM_SIZE, + num_buffers: int = 32, **kwargs, ) -> "Publisher": + # We have to fill in some parts of this class using async + pub = cls(topic, graph_address, num_buffers, **kwargs) + + graph_service = GraphService(graph_address) reader, writer = await graph_service.open_connection() - writer.write(Command.PUBLISH.value) - id = UUID(await read_str(reader)) - pub = cls(id, topic, graph_service, **kwargs) - writer.write(uint64_to_bytes(pub.pid) + encode_str(pub.topic)) - pub._shm = await graph_service.create_shm(pub._num_buffers, buf_size) + pub._shm = await graph_service.create_shm(num_buffers, buf_size) start_port = int( os.getenv(PUBLISHER_START_PORT_ENV, PUBLISHER_START_PORT_DEFAULT) ) sock = create_socket(host, port, start_port=start_port) - server = await asyncio.start_server(pub._on_connection, sock=sock) - pub._address = Address(*sock.getsockname()) - pub._address.to_stream(writer) + address = Address(*sock.getsockname()) + + server = await asyncio.start_server(pub._channel_connect, sock=sock) + + writer.write(Command.PUBLISH.value) + writer.write(encode_str(topic)) + address.to_stream(writer) + + result = await reader.readexactly(1) + if result != Command.COMPLETE.value: + logger.warning(f'Could not create publisher {topic=}') + + pub.id = UUID(await read_str(reader)) + pub._graph_task = asyncio.create_task(pub._graph_connection(reader, writer)) async def serve() -> None: try: await server.serve_forever() - except asyncio.CancelledError: # FIXME: Poor form? - logger.debug("pubclient serve is Cancelled...") + except asyncio.CancelledError: + logger.debug("{pub.log_name} cancelled") finally: await close_server(server) - pub._connection_task = asyncio.create_task(serve(), name=f"pub_{str(id)}") + pub._connection_task = asyncio.create_task(serve(), name=pub.log_name) def on_done(_: asyncio.Future) -> None: - logger.debug("Closing pub server task.") + logger.debug("{pub.log_name} done") pub._connection_task.add_done_callback(on_done) - MessageCache[id] = Cache(pub._num_buffers) - await pub._initialized.wait() + + # Create the local Channel (it shouldn't already exist) + await CHANNELS.get(pub.id, graph_address) + return pub def __init__( self, - id: UUID, topic: str, - graph_service: GraphService, + graph_address: AddressType | None = None, num_buffers: int = 32, start_paused: bool = False, force_tcp: bool = False, ) -> None: - self.id = id + """DO NOT USE this constructor to make a Publisher; use `create` instead""" self.pid = os.getpid() self.topic = topic self._msg_id = 0 - self._subscribers = dict() - self._subscriber_tasks = dict() + self._channels = dict() + self._channel_tasks = dict() self._running = asyncio.Event() if not start_paused: self._running.set() self._num_buffers = num_buffers self._backpressure = Backpressure(num_buffers) self._force_tcp = force_tcp - self._initialized = asyncio.Event() self._last_backpressure_event = -1 - self._graph_service = graph_service + self._graph_address = graph_address + + @property + def log_name(self) -> str: + return f"pub_{self.topic}{str(self.id)}" def close(self) -> None: self._graph_task.cancel() self._shm.close() self._connection_task.cancel() - for task in self._subscriber_tasks.values(): + for task in self._channel_tasks.values(): task.cancel() async def wait_closed(self) -> None: @@ -145,7 +170,7 @@ async def wait_closed(self) -> None: await self._graph_task with suppress(asyncio.CancelledError): await self._connection_task - for task in self._subscriber_tasks.values(): + for task in self._channel_tasks.values(): with suppress(asyncio.CancelledError): await task @@ -158,9 +183,6 @@ async def _graph_connection( if not cmd: break - elif cmd == Command.COMPLETE.value: - self._initialized.set() - elif cmd == Command.PAUSE.value: self._running.clear() @@ -184,35 +206,33 @@ async def _graph_connection( finally: await close_stream_writer(writer) - async def _on_connection( + async def _channel_connect( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter ) -> None: - id_str = await read_str(reader) - id = UUID(id_str) - pid = await read_int(reader) - topic = await read_str(reader) - - # Subscriber determines if they have SHM access - writer.write(encode_str(self._shm.name)) - shm_access = bool(await read_int(reader)) - - writer.write( - encode_str(str(self.id)) - + uint64_to_bytes(self.pid) - + encode_str(self.topic) - + uint64_to_bytes(self._num_buffers) - ) - - info = SubscriberInfo(id, writer, pid, topic, shm_access) - coro = self._handle_subscriber(info, reader) - self._subscriber_tasks[id] = asyncio.create_task(coro) + """ Only messagechannel._Channel will connect here """ + + cmd = await reader.readexactly(1) + + if len(cmd) == 0: + return + + if cmd == Command.CHANNEL.value: + channel_id_str = await read_str(reader) + channel_id = UUID(channel_id_str) + writer.write(encode_str(self._shm.name)) + shm_ok = await reader.readexactly(1) == Command.SHM_OK.value + pid = await read_int(reader) + info = PubChannelInfo(channel_id, writer, self.id, pid, shm_ok) + coro = self._handle_channel(info, reader) + self._channel_tasks[channel_id] = asyncio.create_task(coro) + writer.write(Command.COMPLETE.value + uint64_to_bytes(self._num_buffers)) await writer.drain() - async def _handle_subscriber( - self, info: SubscriberInfo, reader: asyncio.StreamReader + async def _handle_channel( + self, info: PubChannelInfo, reader: asyncio.StreamReader ) -> None: - self._subscribers[info.id] = info + self._channels[info.id] = info try: while True: @@ -226,12 +246,12 @@ async def _handle_subscriber( self._backpressure.free(info.id, msg_id % self._num_buffers) except (ConnectionResetError, BrokenPipeError): - logger.debug(f"Publisher {self.id}: Subscriber {id} connection fail") + logger.debug(f"Publisher {self.id}: Channel {info.id} connection fail") finally: self._backpressure.free(info.id) - await close_stream_writer(self._subscribers[info.id].writer) - del self._subscribers[info.id] + await close_stream_writer(self._channels[info.id].writer) + del self._channels[info.id] async def sync(self) -> None: """Pause and drain backpressure""" @@ -261,58 +281,58 @@ async def broadcast(self, obj: Any) -> None: self._last_backpressure_event = time.time() await self._backpressure.wait(buf_idx) - MessageCache[self.id].put(self._msg_id, obj) - - for sub in list(self._subscribers.values()): - if not self._force_tcp and sub.shm_access: - if sub.pid == self.pid: - sub.writer.write(Command.TX_LOCAL.value + msg_id_bytes) - else: - try: - # Push cache to shm (if not already there) - MessageCache[self.id].push(self._msg_id, self._shm) + # Get local channel and put variable there for local tx + channel = await CHANNELS.get(self.id, self._graph_address) + channel.put(self._msg_id, obj) - except UndersizedMemory as e: - new_shm = await self._graph_service.create_shm( - self._num_buffers, e.req_size * 2 - ) + if any(ch.pid != self.pid or not ch.shm_ok for ch in self._channels.values()): + with MessageMarshal.serialize(self._msg_id, obj) as (total_size, header, buffers): + total_size_bytes = uint64_to_bytes(total_size) + if any(ch.pid != self.pid and ch.shm_ok for ch in self._channels.values()): + if self._shm.buf_size < total_size: + new_shm = await GraphService(self._graph_address).create_shm(self._num_buffers, total_size * 2) + for i in range(self._num_buffers): with self._shm.buffer(i, readonly=True) as from_buf: with new_shm.buffer(i) as to_buf: MessageMarshal.copy_obj(from_buf, to_buf) - + self._shm.close() self._shm = new_shm - MessageCache[self.id].push(self._msg_id, self._shm) - sub.writer.write( - Command.TX_SHM.value + \ - msg_id_bytes + \ - encode_str(self._shm.name) - ) + with self._shm.buffer(buf_idx) as mem: + MessageMarshal._write(mem, header, buffers) - else: - with MessageMarshal.serialize(self._msg_id, obj) as ser_obj: - total_size, header, buffers = ser_obj - total_size_bytes = uint64_to_bytes(total_size) - - sub.writer.write( - Command.TX_TCP.value + \ - msg_id_bytes + \ - total_size_bytes + \ - header + \ - b''.join([buffer for buffer in buffers]) - ) + for channel in self._channels.values(): - try: - await sub.writer.drain() - self._backpressure.lease(sub.id, buf_idx) - - except (ConnectionResetError, BrokenPipeError): - logger.debug( - f"Publisher {self.id}: Subscriber {sub.id} connection fail" - ) - continue + if self.pid == channel.pid and channel.shm_ok: + continue # Local transmission handled by channel.put + + elif self.pid != channel.pid and channel.shm_ok: + channel.writer.write( + Command.TX_SHM.value + \ + msg_id_bytes + \ + encode_str(self._shm.name) + ) + + elif self.pid != channel.pid and not channel.shm_ok: + channel.writer.write( + Command.TX_TCP.value + \ + msg_id_bytes + \ + total_size_bytes + \ + header + \ + b''.join([buffer for buffer in buffers]) + ) + + try: + await channel.writer.drain() + self._backpressure.lease(channel.id, buf_idx) + + except (ConnectionResetError, BrokenPipeError): + logger.debug( + f"Publisher {self.id}: Channel {channel.id} connection fail" + ) + continue self._msg_id += 1 diff --git a/src/ezmsg/core/server.py b/src/ezmsg/core/server.py index 42020483..c1c1461c 100644 --- a/src/ezmsg/core/server.py +++ b/src/ezmsg/core/server.py @@ -10,6 +10,7 @@ from .netprotocol import ( Address, AddressType, + DEFAULT_HOST, close_server, close_stream_writer, create_socket, @@ -98,6 +99,7 @@ async def api( T = typing.TypeVar("T", bound=ThreadedAsyncServer) +# TODO: Get rid of Service Manager and drop this in its own file class ServiceManager(typing.Generic[T]): _address: typing.Optional[Address] = None @@ -139,7 +141,7 @@ def address(self) -> Address: @classmethod def default_address(cls) -> Address: - address_str = os.environ.get(cls.ADDR_ENV, f"127.0.0.1:{cls.PORT_DEFAULT}") + address_str = os.environ.get(cls.ADDR_ENV, f"{DEFAULT_HOST}:{cls.PORT_DEFAULT}") return Address.from_string(address_str) def create_server(self) -> T: diff --git a/src/ezmsg/core/shm.py b/src/ezmsg/core/shm.py index 41a3fdb2..b7951973 100644 --- a/src/ezmsg/core/shm.py +++ b/src/ezmsg/core/shm.py @@ -11,6 +11,7 @@ from .netprotocol import ( close_stream_writer, bytes_to_uint, + uint64_to_bytes, ) logger = logging.getLogger("ezmsg") @@ -74,7 +75,7 @@ def __init__(self, name: str) -> None: with self._shm.buf[8:16] as buf_size_mem: self.buf_size = bytes_to_uint(buf_size_mem) - buf_starts = [buf_idx * self.buf_size for buf_idx in range(self.num_buffers)] + buf_starts = [buf_idx * (self.buf_size + 16) for buf_idx in range(self.num_buffers)] buf_stops = [buf_start + self.buf_size for buf_start in buf_starts] buf_data_block_starts = [buf_start + 16 for buf_start in buf_starts] @@ -82,9 +83,8 @@ def __init__(self, name: str) -> None: slice(*seg) for seg in zip(buf_data_block_starts, buf_stops) ] - # TODO: Get rid of underscore and rename to attach @classmethod - def _create( + def attach( cls, shm_name: str, reader: asyncio.StreamReader, writer: asyncio.StreamWriter ) -> "SHMContext": context = cls(shm_name) @@ -132,7 +132,7 @@ async def close_shm(self) -> None: return except BufferError: logger.debug("BufferError caught... Sleeping.") - await asyncio.sleep(1) + await asyncio.sleep(0.1) async def wait_closed(self) -> None: with suppress(asyncio.CancelledError): @@ -154,10 +154,15 @@ class SHMInfo: @classmethod def create( - cls, uuid: UUID, num_buffers: int, buf_size: int + cls, num_buffers: int, buf_size: int ) -> "SHMInfo": - ... - + buf_size += 16 * num_buffers # Repeated header info makes this a bit bigger + shm = SharedMemory(size=num_buffers * buf_size, create=True) + shm.buf[:] = b"0" * len(shm.buf) # Guarantee zeros + shm.buf[0:8] = uint64_to_bytes(num_buffers) + shm.buf[8:16] = uint64_to_bytes(buf_size) + return cls(shm) + def lease( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter ) -> "asyncio.Task[None]": diff --git a/src/ezmsg/core/subclient.py b/src/ezmsg/core/subclient.py index 50e1aa0c..960bccc2 100644 --- a/src/ezmsg/core/subclient.py +++ b/src/ezmsg/core/subclient.py @@ -1,4 +1,3 @@ -import os import asyncio import logging import typing @@ -8,21 +7,14 @@ from copy import deepcopy from .graphserver import GraphService -from .shm import SHMContext -from .messagecache import MessageCache, Cache -from .messagemarshal import MessageMarshal +from .messagechannel import CHANNELS from .netprotocol import ( - Address, - UINT64_SIZE, - uint64_to_bytes, - bytes_to_uint, - read_int, + AddressType, read_str, encode_str, close_stream_writer, Command, - PublisherInfo, ) @@ -31,62 +23,55 @@ class Subscriber: id: UUID - pid: int topic: str - _initialized: asyncio.Event + _graph_address: AddressType | None _graph_task: "asyncio.Task[None]" - _publishers: typing.Dict[UUID, PublisherInfo] - _publisher_tasks: typing.Dict[UUID, "asyncio.Task[None]"] - _shms: typing.Dict[UUID, SHMContext] + _cur_pubs: typing.Set[UUID] _incoming: "asyncio.Queue[typing.Tuple[UUID, int]]" - _graph_service: GraphService - @classmethod async def create( - cls, topic: str, graph_service: GraphService, **kwargs + cls, + topic: str, + graph_address: AddressType | None, + **kwargs ) -> "Subscriber": - reader, writer = await graph_service.open_connection() + reader, writer = await GraphService(graph_address).open_connection() writer.write(Command.SUBSCRIBE.value) + writer.write(encode_str(topic)) + + result = await reader.readexactly(1) + if result != Command.COMPLETE.value: + logger.warning(f'Could not create subscriber {topic=}') + id_str = await read_str(reader) - sub = cls(UUID(id_str), topic, graph_service, **kwargs) - writer.write(uint64_to_bytes(sub.pid)) - writer.write(encode_str(sub.topic)) + + sub = cls(UUID(id_str), topic, graph_address, **kwargs) sub._graph_task = asyncio.create_task(sub._graph_connection(reader, writer)) - await sub._initialized.wait() + return sub - def __init__( - self, id: UUID, topic: str, graph_service: GraphService, **kwargs + def __init__(self, + id: UUID, + topic: str, + graph_address: AddressType | None, + **kwargs ) -> None: + """DO NOT USE this constructor, use Subscriber.create instead""" self.id = id - self.pid = os.getpid() self.topic = topic + self._graph_address = graph_address - self._publishers = dict() - self._publisher_tasks = dict() - self._shms = dict() + self._cur_pubs = set() self._incoming = asyncio.Queue() - self._initialized = asyncio.Event() - - self._graph_service = graph_service def close(self) -> None: self._graph_task.cancel() - for task in self._publisher_tasks.values(): - task.cancel() - for shm in self._shms.values(): - shm.close() async def wait_closed(self) -> None: with suppress(asyncio.CancelledError): await self._graph_task - for task in self._publisher_tasks.values(): - with suppress(asyncio.CancelledError): - await task - for shm in self._shms.values(): - await shm.wait_closed() async def _graph_connection( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter @@ -97,33 +82,17 @@ async def _graph_connection( if not cmd: break - elif cmd == Command.COMPLETE.value: - self._initialized.set() - elif cmd == Command.UPDATE.value: - pub_addresses: typing.Dict[UUID, Address] = {} - connections = await read_str(reader) - connections = connections.strip(",") - if len(connections): - for connection in connections.split(","): - pub_id, pub_address = connection.split("@") - pub_id = UUID(pub_id) - pub_address = Address.from_string(pub_address) - pub_addresses[pub_id] = pub_address - - for id in set(pub_addresses.keys() - self._publishers.keys()): - connected = asyncio.Event() - coro = self._handle_publisher(id, pub_addresses[id], connected) - task_name = f"sub{self.id}:_handle_publisher({id})" - self._publisher_tasks[id] = asyncio.create_task( - coro, name=task_name - ) - await connected.wait() - - for id in set(self._publishers.keys() - pub_addresses.keys()): - self._publisher_tasks[id].cancel() - with suppress(asyncio.CancelledError): - await self._publisher_tasks[id] + update = await read_str(reader) + pub_ids = set([UUID(id) for id in update.split(',')]) + + for pub_id in set(pub_ids - self._cur_pubs): + channel = await CHANNELS.get(pub_id, self._graph_address) + channel.subscribe(self.id, self._incoming) + + for pub_id in set(self._cur_pubs - pub_ids): + channel = await CHANNELS.get(pub_id, self._graph_address) + channel.unsubscribe(self.id) writer.write(Command.COMPLETE.value) await writer.drain() @@ -137,87 +106,9 @@ async def _graph_connection( logger.debug(f"Subscriber {self.id} lost connection to graph server") finally: + CHANNELS.unsubscribe_all(self.id, self._graph_address) # good idea? await close_stream_writer(writer) - async def _handle_publisher( - self, id: UUID, address: Address, connected: asyncio.Event - ) -> None: - reader, writer = await asyncio.open_connection(*address) - writer.write(encode_str(str(self.id))) - writer.write(uint64_to_bytes(self.pid)) - writer.write(encode_str(self.topic)) - await writer.drain() - - # Pub replies with current shm name - # We attempt to attach and let pub know if we have SHM access - shm_name = await read_str(reader) - try: - (await self._graph_service.attach_shm(shm_name)).close() - writer.write(uint64_to_bytes(1)) - except (ValueError, OSError): - writer.write(uint64_to_bytes(0)) - await writer.drain() - - pub_id_str = await read_str(reader) - pub_pid = await read_int(reader) - pub_topic = await read_str(reader) - num_buffers = await read_int(reader) - - if id != UUID(pub_id_str): - raise ValueError("Unexpected Publisher ID") - - # NOTE: Not thread safe - if id not in MessageCache: - MessageCache[id] = Cache(num_buffers) - - self._publishers[id] = PublisherInfo(id, writer, pub_pid, pub_topic, address) - - connected.set() - - try: - while True: - msg = await reader.read(1) - if not msg: - break - - msg_id_bytes = await reader.read(UINT64_SIZE) - msg_id = bytes_to_uint(msg_id_bytes) - - if msg == Command.TX_SHM.value: - shm_name = await read_str(reader) - - if id not in self._shms or self._shms[id].name != shm_name: - if id in self._shms: - self._shms[id].close() - try: - self._shms[id] = await self._graph_service.attach_shm( - shm_name - ) - except ValueError: - logger.info( - "Invalid SHM received from publisher; may be dead" - ) - raise - - # FIXME: TCP connections could be more efficient. - # https://github.com/iscoe/ezmsg/issues/5 - elif msg == Command.TX_TCP.value: - buf_size = await read_int(reader) - obj_bytes = await reader.readexactly(buf_size) - - with MessageMarshal.obj_from_mem(memoryview(obj_bytes)) as obj: - MessageCache[id].put(msg_id, obj) - - self._incoming.put_nowait((id, msg_id)) - - except (ConnectionResetError, BrokenPipeError): - logger.debug(f"connection fail: sub:{self.id} -> pub:{id}") - - finally: - await close_stream_writer(self._publishers[id].writer) - del self._publishers[id] - logger.debug(f"disconnected: sub:{self.id} -> pub:{id}") - async def recv(self) -> typing.Any: out_msg = None async with self.recv_zero_copy() as msg: @@ -227,18 +118,9 @@ async def recv(self) -> typing.Any: @asynccontextmanager async def recv_zero_copy(self) -> typing.AsyncGenerator[typing.Any, None]: id, msg_id = await self._incoming.get() - msg_id_bytes = uint64_to_bytes(msg_id) - try: - shm = self._shms.get(id, None) - with MessageCache[id].get(msg_id, shm) as msg: - yield msg + channel = await CHANNELS.get(id, self._graph_address) + with channel.get(msg_id, self.id) as msg: + yield msg - finally: - if id in self._publishers: - try: - ack = Command.RX_ACK.value + msg_id_bytes - self._publishers[id].writer.write(ack) - await self._publishers[id].writer.drain() - except (BrokenPipeError, ConnectionResetError): - logger.debug(f"ack fail: sub:{self.id} -> pub:{id}") + From bd53ad739a8164b05c57c1a8c45cc804d3514b75 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Wed, 20 Aug 2025 14:11:22 -0400 Subject: [PATCH 24/88] bugfixes --- src/ezmsg/core/graphserver.py | 162 +++++++++++++++++++--------------- src/ezmsg/core/netprotocol.py | 1 - src/ezmsg/core/pubclient.py | 2 +- 3 files changed, 91 insertions(+), 74 deletions(-) diff --git a/src/ezmsg/core/graphserver.py b/src/ezmsg/core/graphserver.py index 125d8569..59cd929a 100644 --- a/src/ezmsg/core/graphserver.py +++ b/src/ezmsg/core/graphserver.py @@ -73,6 +73,10 @@ async def api( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter ) -> None: try: + # Always start communications by telling the client our ezmsg version + # This helps us get ahead of surprisingly common situations where + # the graph server and the graph clients are running different versions + # of ezmsg which could result in borked comms. writer.write(encode_str(__version__)) await writer.drain() @@ -117,80 +121,79 @@ async def api( with suppress(asyncio.CancelledError): await shm_info.lease(reader, writer) + elif req == Command.CHANNEL.value: + channel_id = uuid1() + pub_id_str = await read_str(reader) + pub_id = UUID(pub_id_str) + + pub_addr = None + try: + pub_info = self.clients[pub_id] + if isinstance(pub_info, PublisherInfo): + # assemble an address the channel should be able to resolve + port = pub_info.address.port + iface = pub_info.writer.transport.get_extra_info("peername")[0] + pub_addr = Address(iface, port) + else: + logger.warning(f"Connecting channel requested {type(pub_info)}") + + except KeyError: + logger.warning(f"Connecting channel requested non-existent publisher {pub_id=}") + + if pub_addr is not None: + writer.write(Command.COMPLETE.value) + writer.write(encode_str(str(channel_id))) + pub_addr.to_stream(writer) + + info = ChannelInfo(channel_id, writer, pub_id) + self.clients[channel_id] = info + self._client_tasks[channel_id] = asyncio.create_task( + self._handle_client(channel_id, reader, writer) + ) + else: + # Error, drop connection + await close_stream_writer(writer) + + # Created a client; we must return early to avoid closing writer + return + else: # We only want to handle one command at a time async with self._command_lock: if req in [ Command.SUBSCRIBE.value, Command.PUBLISH.value, - Command.CHANNEL.value ]: - id = uuid1() - - if req in [Command.SUBSCRIBE.value, Command.PUBLISH.value]: - topic = await read_str(reader) - - if req == Command.SUBSCRIBE.value: - info = SubscriberInfo(id, writer, topic) - self.clients[id] = info - self._client_tasks[id] = asyncio.create_task( - self._handle_client(id, reader, writer) - ) - - writer.write(Command.COMPLETE.value) - writer.write(encode_str(str(id))) - - await self._notify_subscriber(info) - - elif req == Command.PUBLISH.value: - address = await Address.from_stream(reader) - info = PublisherInfo(id, writer, topic, address) - self.clients[id] = info - self._client_tasks[id] = asyncio.create_task( - self._handle_client(id, reader, writer) - ) - - writer.write(Command.COMPLETE.value) - writer.write(encode_str(str(id))) - - for sub in self._downstream_subs(info.topic): - await self._notify_subscriber(sub) - - - elif req == Command.CHANNEL.value: - pub_id_str = await read_str(reader) - pub_id = UUID(pub_id_str) - - pub_addr = None - try: - pub_info = self.clients[pub_id] - if isinstance(pub_info, PublisherInfo): - # assemble an address the channel should be able to resolve - port = pub_info.address.port - iface = pub_info.writer.transport.get_extra_info("peername")[0] - pub_addr = Address(iface, port) - else: - logger.warning(f"Connecting channel requested {type(pub_info)}") - - except KeyError: - logger.warning(f"Connecting channel requested non-existent publisher {pub_id=}") - - if pub_addr is not None: - writer.write(Command.COMPLETE.value) - writer.write(encode_str(str(id))) - pub_addr.to_stream(writer) - - info = ChannelInfo(id, writer, pub_id) - self.clients[id] = info - self._client_tasks[id] = asyncio.create_task( - self._handle_client(id, reader, writer) - ) - else: - # Error, drop connection - await close_stream_writer(writer) - return + client_id = uuid1() + topic = await read_str(reader) - await writer.drain() + if req == Command.SUBSCRIBE.value: + info = SubscriberInfo(client_id, writer, topic) + self.clients[client_id] = info + self._client_tasks[client_id] = asyncio.create_task( + self._handle_client(client_id, reader, writer) + ) + + writer.write(Command.COMPLETE.value) + writer.write(encode_str(str(client_id))) + + await self._notify_subscriber(info) + + elif req == Command.PUBLISH.value: + address = await Address.from_stream(reader) + info = PublisherInfo(client_id, writer, topic, address) + self.clients[client_id] = info + self._client_tasks[client_id] = asyncio.create_task( + self._handle_client(client_id, reader, writer) + ) + + writer.write(Command.COMPLETE.value) + writer.write(encode_str(str(client_id))) + + for sub in self._downstream_subs(info.topic): + await self._notify_subscriber(sub) + + # Created a client, must return early to avoid closing writer return elif req in [Command.CONNECT.value, Command.DISCONNECT.value]: @@ -248,9 +251,21 @@ async def api( await close_stream_writer(writer) async def _handle_client( - self, id: UUID, reader: asyncio.StreamReader, writer: asyncio.StreamWriter + self, client_id: UUID, reader: asyncio.StreamReader, writer: asyncio.StreamWriter ) -> None: - logger.debug(f"Graph Server: Client connected: {id}") + """ + The lifecycle of a graph client is tied to the lifecycle of this TCP connection. + We always attempt to read from the reader, and if the connection is ever closed + from the other side, reader.read(1) returns empty, we break the loop, and delete + the client from our list of connected clients. Because we're always reading + from the client within this task, we have to handle all client responses within + this task. + + NOTE: _notify_subscriber requires a COMPLETE once the subscriber has reconfigured + its connections. ClientInfo.sync_writer gives us the mechanism by which to block + _notify_subscriber until the COMPLETE is received in this context. + """ + logger.debug(f"Graph Server: Client connected: {client_id}") try: while True: @@ -260,24 +275,27 @@ async def _handle_client( break if req == Command.COMPLETE.value: - self.clients[id].set_sync() + self.clients[client_id].set_sync() except (ConnectionResetError, BrokenPipeError) as e: - logger.debug(f"Client {id} disconnected from GraphServer: {e}") + logger.debug(f"Client {client_id} disconnected from GraphServer: {e}") finally: - self.clients[id].set_sync() - del self.clients[id] + self.clients[client_id].set_sync() + del self.clients[client_id] await close_stream_writer(writer) async def _notify_subscriber(self, sub: SubscriberInfo) -> None: try: pub_ids = [str(pub.id) for pub in self._upstream_pubs(sub.topic)] + # Update requires us to read a 'COMPLETE' + # This cannot be done from this context async with sub.sync_writer() as writer: notify_str = ",".join(pub_ids) writer.write(Command.UPDATE.value) writer.write(encode_str(notify_str)) + except (ConnectionResetError, BrokenPipeError) as e: logger.debug(f"Failed to update Subscriber {sub.id}: {e}") diff --git a/src/ezmsg/core/netprotocol.py b/src/ezmsg/core/netprotocol.py index 4ee9b531..cef6729e 100644 --- a/src/ezmsg/core/netprotocol.py +++ b/src/ezmsg/core/netprotocol.py @@ -78,7 +78,6 @@ async def sync_writer(self) -> typing.AsyncGenerator[asyncio.StreamWriter, None] await self._pending.wait() try: yield self.writer - await self.writer.drain() self._pending.clear() await self._pending.wait() finally: diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index a1613aee..2ed3283a 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -11,7 +11,7 @@ from .shm import SHMContext from .graphserver import GraphService from .messagechannel import CHANNELS -from .messagemarshal import MessageMarshal, UndersizedMemory +from .messagemarshal import MessageMarshal from .netprotocol import ( Address, From e8bd0aabb6987e6b8f92e9631425e96cfdcb1610 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Wed, 20 Aug 2025 15:13:20 -0400 Subject: [PATCH 25/88] test scripts --- src/ezmsg/core/pubclient.py | 15 ++--- tests/test_new_messaging.py | 97 ++++++++++++++++++++++++++++++ tests/test_reconstitute.py | 114 ++++++++++++++++++++++++++++++++++++ 3 files changed, 219 insertions(+), 7 deletions(-) create mode 100644 tests/test_new_messaging.py create mode 100644 tests/test_reconstitute.py diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index 2ed3283a..9482b1dd 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -282,14 +282,15 @@ async def broadcast(self, obj: Any) -> None: await self._backpressure.wait(buf_idx) # Get local channel and put variable there for local tx - channel = await CHANNELS.get(self.id, self._graph_address) - channel.put(self._msg_id, obj) + if not self._force_tcp: + channel = await CHANNELS.get(self.id, self._graph_address) + channel.put(self._msg_id, obj) - if any(ch.pid != self.pid or not ch.shm_ok for ch in self._channels.values()): + if self._force_tcp or any(ch.pid != self.pid or not ch.shm_ok for ch in self._channels.values()): with MessageMarshal.serialize(self._msg_id, obj) as (total_size, header, buffers): total_size_bytes = uint64_to_bytes(total_size) - if any(ch.pid != self.pid and ch.shm_ok for ch in self._channels.values()): + if not self._force_tcp and any(ch.pid != self.pid and ch.shm_ok for ch in self._channels.values()): if self._shm.buf_size < total_size: new_shm = await GraphService(self._graph_address).create_shm(self._num_buffers, total_size * 2) @@ -306,17 +307,17 @@ async def broadcast(self, obj: Any) -> None: for channel in self._channels.values(): - if self.pid == channel.pid and channel.shm_ok: + if not self._force_tcp and self.pid == channel.pid and channel.shm_ok: continue # Local transmission handled by channel.put - elif self.pid != channel.pid and channel.shm_ok: + elif not self._force_tcp and self.pid != channel.pid and channel.shm_ok: channel.writer.write( Command.TX_SHM.value + \ msg_id_bytes + \ encode_str(self._shm.name) ) - elif self.pid != channel.pid and not channel.shm_ok: + else: channel.writer.write( Command.TX_TCP.value + \ msg_id_bytes + \ diff --git a/tests/test_new_messaging.py b/tests/test_new_messaging.py new file mode 100644 index 00000000..be238a1d --- /dev/null +++ b/tests/test_new_messaging.py @@ -0,0 +1,97 @@ +import asyncio + +from ezmsg.core.graphserver import GraphServer, GraphService + +from ezmsg.core.subclient import Subscriber +from ezmsg.core.pubclient import Publisher + +# ADDR = ('127.0.0.1', 12345) +ADDR = ('0.0.0.0', 12345) +MAX_COUNT = 100 +TOPIC = '/TEST' + +async def handle_pub(pub: Publisher) -> None: + print('Publisher Task Launched') + + count = 0 + + while True: + await pub.broadcast(f'{count=}') + await asyncio.sleep(0.1) + count += 1 + if count >= MAX_COUNT: break + + print('Publisher Task Concluded') + +async def handle_sub(sub: Subscriber) -> None: + print('Subscriber Task Launched') + + rx_count = 0 + while True: + async with sub.recv_zero_copy() as msg: + print(msg) + rx_count += 1 + if rx_count >= MAX_COUNT: break + + print('Subscriber Task Concluded') + + +async def host(): + # Manually create a GraphServer + server = GraphServer() + server.start(ADDR) + + print(f'Created GraphServer @ {server.address}') + + # Create a graph_service that will interact with this GraphServer + graph_service = GraphService(ADDR) + await graph_service.ensure() + + test_pub = await Publisher.create(TOPIC, ADDR, host="0.0.0.0") + test_sub1 = await Subscriber.create(TOPIC, ADDR) + test_sub2 = await Subscriber.create(TOPIC, ADDR) + + pub_task = asyncio.Task(handle_pub(test_pub)) + sub_task_1 = asyncio.Task(handle_sub(test_sub1)) + sub_task_2 = asyncio.Task(handle_sub(test_sub2)) + + try: + await asyncio.wait([pub_task, sub_task_1, sub_task_2]) + + finally: + server.stop() + + print('Done') + + +async def attach_client(): + # Attach to a running GraphServer + graph_service = GraphService(ADDR) + await graph_service.ensure() + + print(f'Connected to GraphServer @ {graph_service.address}') + + sub = await Subscriber.create(TOPIC, ADDR) + + while True: + async with sub.recv_zero_copy() as msg: + print(msg) + + +if __name__ == '__main__': + from dataclasses import dataclass + from argparse import ArgumentParser + + parser = ArgumentParser() + parser.add_argument('--attach', action = 'store_true', help = 'attach to running graph') + + @dataclass + class Args: + attach: bool + + args = Args(**vars(parser.parse_args())) + + if args.attach: + asyncio.run(attach_client()) + else: + asyncio.run(host()) \ No newline at end of file diff --git a/tests/test_reconstitute.py b/tests/test_reconstitute.py new file mode 100644 index 00000000..8b1e418f --- /dev/null +++ b/tests/test_reconstitute.py @@ -0,0 +1,114 @@ +import asyncio +from copy import deepcopy +import timeit +import matplotlib.pyplot as plt + +from ezmsg.core.graphserver import GraphServer, GraphService +from ezmsg.util.messages.axisarray import AxisArray +from ezmsg.core.messagemarshal import MessageMarshal + +import numpy as np + +ADDR = ('127.0.0.1', 12345) +ITERS = 10000 + +def setup_graph_and_shm(msg_size): + server = GraphServer() + server.start(ADDR) + graph_service = GraphService(ADDR) + loop = asyncio.get_event_loop() + loop.run_until_complete(graph_service.ensure()) + aa = AxisArray(data=np.random.normal(size=(msg_size,)), dims=['a']) + with MessageMarshal.serialize(0, aa) as (size, _, _): + msg_size_bytes = size + shm = loop.run_until_complete(graph_service.create_shm(num_buffers=32, buf_size=msg_size_bytes + 100)) + with shm.buffer(0) as mem: + MessageMarshal.to_mem(0, aa, mem) + return server, shm, msg_size_bytes + +def recon_with_deepcopy_func(shm): + with shm.buffer(0, readonly=True) as mem: + with MessageMarshal.obj_from_mem(mem) as obj: + deepcopy(obj) + +def recon_without_deepcopy_func(shm): + with shm.buffer(0, readonly=True) as mem: + with MessageMarshal.obj_from_mem(mem) as obj: + pass + + +def main(): + msg_sizes = [2**i for i in range(8, 25, 2)] # 2**8, 2**10, ..., 2**24 + fanouts = [1, 2, 4, 8, 16, 32] + results = {} + msg_size_bytes_list = [] + + for msg_size in msg_sizes: + print(f"\nTesting MSG_SIZE={msg_size}") + server, shm, msg_size_bytes = setup_graph_and_shm(msg_size) + msg_size_bytes_list.append(msg_size_bytes) + + # Warm up + for _ in range(100): + recon_with_deepcopy_func(shm) + recon_without_deepcopy_func(shm) + + # Time a single reconstitution and a single deepcopy + recon_time = timeit.timeit( + stmt=lambda: recon_without_deepcopy_func(shm), + number=ITERS + ) / ITERS * 1e9 # ns + deepcopy_time = timeit.timeit( + stmt=lambda: recon_with_deepcopy_func(shm), + number=ITERS + ) / ITERS * 1e9 # ns + + print(f'Per reconstitution (ns): {recon_time}') + print(f'Per deepcopy (ns): {deepcopy_time}') + + # For each fanout, compute total time for both strategies + total_recon = [] + total_deepcopy = [] + for fanout in fanouts: + # Strategy 1: 1 reconstitution + 1 deepcopy, rest are free + total_deepcopy.append(recon_time + deepcopy_time) + # Strategy 2: reconstitute N times (no deepcopy) + total_recon.append(fanout * recon_time) + + results[msg_size_bytes] = { + 'fanouts': fanouts, + 'total_recon': total_recon, + 'total_deepcopy': total_deepcopy, + } + + server.stop() + + # Plot results + plt.figure(figsize=(12, 8)) + for msg_size_bytes in msg_size_bytes_list: + fanouts = results[msg_size_bytes]['fanouts'] + plt.plot( + fanouts, + results[msg_size_bytes]['total_recon'], + label=f'Recon N times ({msg_size_bytes} bytes)', + marker='o', linestyle='--', alpha=0.7 + ) + plt.plot( + fanouts, + results[msg_size_bytes]['total_deepcopy'], + label=f'Recon 1 + deepcopy N-1 ({msg_size_bytes} bytes)', + marker='x', linestyle='-', alpha=0.7 + ) + plt.xscale('log', base=2) + plt.yscale('log') + plt.xlabel('Fanout (number of consumers)') + plt.ylabel('Total time (ns)') + plt.title('Fanout Tradeoff: Reconstitute N times vs Reconstitute+Deepcopy') + plt.legend(fontsize='small', ncol=2) + plt.grid(True, which='both', ls='--', alpha=0.5) + plt.tight_layout() + plt.savefig('fanout_tradeoff.png') + print('Plot saved as fanout_tradeoff.png') + +if __name__ == '__main__': + main() \ No newline at end of file From 1d3948e1a4d412866f06cd0334a500c332f97188 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Wed, 20 Aug 2025 16:14:16 -0400 Subject: [PATCH 26/88] backpressure fixed --- src/ezmsg/core/pubclient.py | 5 +++-- tests/test_new_messaging.py | 32 +++++++++++++++++++------------- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index 9482b1dd..a87efd97 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -285,6 +285,7 @@ async def broadcast(self, obj: Any) -> None: if not self._force_tcp: channel = await CHANNELS.get(self.id, self._graph_address) channel.put(self._msg_id, obj) + self._backpressure.lease(channel.id, buf_idx) if self._force_tcp or any(ch.pid != self.pid or not ch.shm_ok for ch in self._channels.values()): with MessageMarshal.serialize(self._msg_id, obj) as (total_size, header, buffers): @@ -307,10 +308,10 @@ async def broadcast(self, obj: Any) -> None: for channel in self._channels.values(): - if not self._force_tcp and self.pid == channel.pid and channel.shm_ok: + if (not self._force_tcp) and self.pid == channel.pid and channel.shm_ok: continue # Local transmission handled by channel.put - elif not self._force_tcp and self.pid != channel.pid and channel.shm_ok: + elif (not self._force_tcp) and self.pid != channel.pid and channel.shm_ok: channel.writer.write( Command.TX_SHM.value + \ msg_id_bytes + \ diff --git a/tests/test_new_messaging.py b/tests/test_new_messaging.py index be238a1d..9dec127b 100644 --- a/tests/test_new_messaging.py +++ b/tests/test_new_messaging.py @@ -5,8 +5,7 @@ from ezmsg.core.subclient import Subscriber from ezmsg.core.pubclient import Publisher -# ADDR = ('127.0.0.1', 12345) -ADDR = ('0.0.0.0', 12345) +PORT = 12345 MAX_COUNT = 100 TOPIC = '/TEST' @@ -29,27 +28,31 @@ async def handle_sub(sub: Subscriber) -> None: rx_count = 0 while True: async with sub.recv_zero_copy() as msg: + await asyncio.sleep(0.15) print(msg) + rx_count += 1 if rx_count >= MAX_COUNT: break print('Subscriber Task Concluded') -async def host(): +async def host(host: str = '127.0.0.1'): # Manually create a GraphServer server = GraphServer() - server.start(ADDR) + server.start((host, PORT)) print(f'Created GraphServer @ {server.address}') # Create a graph_service that will interact with this GraphServer - graph_service = GraphService(ADDR) + graph_service = GraphService((host, PORT)) await graph_service.ensure() - test_pub = await Publisher.create(TOPIC, ADDR, host="0.0.0.0") - test_sub1 = await Subscriber.create(TOPIC, ADDR) - test_sub2 = await Subscriber.create(TOPIC, ADDR) + test_pub = await Publisher.create(TOPIC, (host, PORT), host=host) + test_sub1 = await Subscriber.create(TOPIC, (host, PORT)) + test_sub2 = await Subscriber.create(TOPIC, (host, PORT)) + + await asyncio.sleep(1.0) pub_task = asyncio.Task(handle_pub(test_pub)) sub_task_1 = asyncio.Task(handle_sub(test_sub1)) @@ -64,17 +67,18 @@ async def host(): print('Done') -async def attach_client(): +async def attach_client(host: str = '127.0.0.1'): # Attach to a running GraphServer - graph_service = GraphService(ADDR) + graph_service = GraphService((host, PORT)) await graph_service.ensure() print(f'Connected to GraphServer @ {graph_service.address}') - sub = await Subscriber.create(TOPIC, ADDR) + sub = await Subscriber.create(TOPIC, (host, PORT)) while True: async with sub.recv_zero_copy() as msg: + await asyncio.sleep(1.0) print(msg) @@ -84,14 +88,16 @@ async def attach_client(): parser = ArgumentParser() parser.add_argument('--attach', action = 'store_true', help = 'attach to running graph') + parser.add_argument('--host', default = "0.0.0.0", help = 'hostname for graphserver') @dataclass class Args: attach: bool + host: str args = Args(**vars(parser.parse_args())) if args.attach: - asyncio.run(attach_client()) + asyncio.run(attach_client(host = args.host)) else: - asyncio.run(host()) \ No newline at end of file + asyncio.run(host(host = args.host)) \ No newline at end of file From 2fed32a740c9c0c051869d1a9cd1f065917f393c Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Thu, 21 Aug 2025 11:58:39 -0400 Subject: [PATCH 27/88] working on bugfixes --- src/ezmsg/core/backend.py | 67 +++++++++------- src/ezmsg/core/backendprocess.py | 12 +-- src/ezmsg/core/graphcontext.py | 38 +++++---- src/ezmsg/core/graphserver.py | 10 +-- src/ezmsg/core/messagechannel.py | 128 ++++++++++++++++++++----------- src/ezmsg/core/pubclient.py | 18 ++--- src/ezmsg/core/subclient.py | 30 +++++--- tests/test_new_messaging.py | 45 ++++++++--- tests/test_reconstitute.py | 57 +++++++------- 9 files changed, 247 insertions(+), 158 deletions(-) diff --git a/src/ezmsg/core/backend.py b/src/ezmsg/core/backend.py index ffedc76a..62e3d974 100644 --- a/src/ezmsg/core/backend.py +++ b/src/ezmsg/core/backend.py @@ -31,47 +31,56 @@ class ExecutionContext: - processes: typing.List[BackendProcess] + _process_units: typing.List[typing.List[Unit]] + _processes: typing.List[BackendProcess] | None term_ev: EventType start_barrier: BarrierType connections: typing.List[typing.Tuple[str, str]] def __init__( self, - processes: typing.List[typing.List[Unit]], - graph_service: GraphService, + process_units: typing.List[typing.List[Unit]], connections: typing.List[typing.Tuple[str, str]] = [], - backend_process: typing.Type[BackendProcess] = DefaultBackendProcess, ) -> None: - if not processes: - raise ValueError("Cannot create an execution context for zero processes") - self.connections = connections + self._process_units = process_units + self._processes = None self.term_ev = Event() - self.start_barrier = Barrier(len(processes)) - self.stop_barrier = Barrier(len(processes)) + self.start_barrier = Barrier(len(process_units)) + self.stop_barrier = Barrier(len(process_units)) - self.processes = [ + def create_processes( + self, + graph_address: AddressType | None, + backend_process: typing.Type[BackendProcess] = DefaultBackendProcess, + ) -> None: + + self._processes = [ backend_process( process_units, self.term_ev, self.start_barrier, self.stop_barrier, - graph_service, + graph_address, ) - for process_units in processes + for process_units in self._process_units ] + @property + def processes(self) -> typing.List[BackendProcess]: + if self._processes is None: + raise ValueError("ExecutionContext has not initialized processes") + else: + return self._processes + @classmethod def setup( cls, components: typing.Mapping[str, Component], - graph_service: GraphService, root_name: typing.Optional[str] = None, connections: typing.Optional[NetworkDefinition] = None, process_components: typing.Optional[typing.Collection[Component]] = None, - backend_process: typing.Type[BackendProcess] = DefaultBackendProcess, force_single_process: bool = False, ) -> typing.Optional["ExecutionContext"]: graph_connections: typing.List[typing.Tuple[str, str]] = [] @@ -132,15 +141,13 @@ def configure_collections(comp: Component): if force_single_process: processes = [[u for pu in processes for u in pu]] - try: - return cls( - processes, - graph_service, - graph_connections, - backend_process, - ) - except ValueError: + if not processes: return None + + return cls( + processes, + graph_connections, + ) def run_system( @@ -173,22 +180,21 @@ def run( components: represents the nodes in the directed acyclic graph. It is a dictionary which contains the ``Components`` to be run mapped to string names. On initialization, ``ezmsg`` will call ``initialize()`` for each :obj:`Unit` and ``configure()`` for each :obj:`Collection`, if defined. - root_name: + root_name: gives a new root name to all nodes in this graph (e.g. [root_name]/COLLECTION/UNIT) connections: represents the edges is a ``NetworkDefinition`` which connects ``OutputStreams`` to ``InputStreams``. On initialization, ``ezmsg`` will create a directed acyclic graph using the contents of this parameter. process_components: a list of ``Components`` which should live in their own process. backend_process: is currently under development. graph_address: the hostname and port of the graph server which ``ezmsg`` should connect to. - If not defined, ``ezmsg`` will start a new graph server at 127.0.0.1:25978. + If not defined, ``ezmsg`` will try 127.0.0.1:25978, or fallback to a new graph server on a random port force_single_process: run all ``Components`` in one process. This is necessary when running ``ezmsg`` in a notebook. - components_kwargs: + components_kwargs: see components """ os.environ["EZMSG_PROFILER"] = profiler_log_name or "ezprofiler.log" # FIXME: This function is the last major re-implementation needed to make this # codebase more maintainable. - graph_service = GraphService(graph_address) if components is not None and isinstance(components, Component): components = {"SYSTEM": components} @@ -202,11 +208,9 @@ def run( with new_threaded_event_loop() as loop: execution_context = ExecutionContext.setup( components, - graph_service, root_name, connections, process_components, - backend_process, force_single_process, ) @@ -215,7 +219,7 @@ def run( # FIXME: When done this way, we don't exit the graph_context on exception async def create_graph_context() -> GraphContext: - return await GraphContext(graph_service).__aenter__() + return await GraphContext(graph_address).__aenter__() # FIXME: This sort of stuff should all be done in a separate async function... # Done this way, its ugly as hell and opens us up to a lot of issues with @@ -224,6 +228,11 @@ async def create_graph_context() -> GraphContext: create_graph_context(), loop ).result() + execution_context.create_processes( + graph_address=graph_context.graph_address, + backend_process=backend_process + ) + async def cleanup_graph() -> None: await graph_context.__aexit__(None, None, None) diff --git a/src/ezmsg/core/backendprocess.py b/src/ezmsg/core/backendprocess.py index 98a46f75..9c535a34 100644 --- a/src/ezmsg/core/backendprocess.py +++ b/src/ezmsg/core/backendprocess.py @@ -6,6 +6,7 @@ import traceback import threading +from abc import abstractmethod from collections import defaultdict from functools import wraps, partial from copy import deepcopy @@ -22,8 +23,7 @@ from .graphserver import GraphService from .pubclient import Publisher from .subclient import Subscriber -from .messagechannel import CHANNELS -from abc import abstractmethod +from .netprotocol import AddressType from typing import ( List, @@ -63,7 +63,7 @@ class BackendProcess(Process): term_ev: EventType start_barrier: BarrierType stop_barrier: BarrierType - graph_service: GraphService + graph_address: AddressType | None def __init__( self, @@ -71,14 +71,14 @@ def __init__( term_ev: EventType, start_barrier: BarrierType, stop_barrier: BarrierType, - graph_service: GraphService, + graph_address: AddressType | None, ) -> None: super().__init__() self.units = units self.term_ev = term_ev self.start_barrier = start_barrier self.stop_barrier = stop_barrier - self.graph_service = graph_service + self.graph_address = graph_address self.task_finished_ev: Optional[threading.Event] = None def run(self) -> None: @@ -99,7 +99,7 @@ class DefaultBackendProcess(BackendProcess): def process(self, loop: asyncio.AbstractEventLoop) -> None: main_func = None - context = GraphContext(self.graph_service) + context = GraphContext(self.graph_address) coro_callables: Dict[str, Callable[[], Coroutine[Any, Any, None]]] = dict() try: diff --git a/src/ezmsg/core/graphcontext.py b/src/ezmsg/core/graphcontext.py index fa18cdcc..4dc91e2d 100644 --- a/src/ezmsg/core/graphcontext.py +++ b/src/ezmsg/core/graphcontext.py @@ -2,6 +2,7 @@ import logging import typing +from .netprotocol import AddressType, Address from .graphserver import GraphServer, GraphService from .pubclient import Publisher from .subclient import Subscriber @@ -23,49 +24,58 @@ class GraphContext: _clients: typing.Set[typing.Union[Publisher, Subscriber]] _edges: typing.Set[typing.Tuple[str, str]] - _graph_service: GraphService + _graph_address: AddressType | None _graph_server: typing.Optional[GraphServer] def __init__( self, - graph_service: typing.Optional[GraphService] = None, + graph_address: AddressType | None = None, ) -> None: self._clients = set() self._edges = set() - self._graph_service = ( - graph_service if graph_service is not None else GraphService() - ) + self._graph_address = graph_address self._graph_server = None + @property + def graph_address(self) -> AddressType | None: + if self._graph_server is not None: + return self._graph_server.address + else: + return self._graph_address + async def publisher(self, topic: str, **kwargs) -> Publisher: - pub = await Publisher.create(topic, self._graph_service, **kwargs) + pub = await Publisher.create( + topic, self.graph_address, **kwargs + ) self._clients.add(pub) return pub async def subscriber(self, topic: str, **kwargs) -> Subscriber: - sub = await Subscriber.create(topic, self._graph_service, **kwargs) + sub = await Subscriber.create( + topic, self.graph_address, **kwargs + ) self._clients.add(sub) return sub async def connect(self, from_topic: str, to_topic: str) -> None: - await self._graph_service.connect(from_topic, to_topic) + await GraphService(self.graph_address).connect(from_topic, to_topic) self._edges.add((from_topic, to_topic)) async def disconnect(self, from_topic: str, to_topic: str) -> None: - await self._graph_service.disconnect(from_topic, to_topic) + await GraphService(self.graph_address).disconnect(from_topic, to_topic) self._edges.discard((from_topic, to_topic)) async def sync(self, timeout: typing.Optional[float] = None) -> None: - await self._graph_service.sync(timeout) + await GraphService(self.graph_address).sync(timeout) async def pause(self) -> None: - await self._graph_service.pause() + await GraphService(self.graph_address).pause() async def resume(self) -> None: - await self._graph_service.resume() + await GraphService(self.graph_address).resume() async def _ensure_servers(self) -> None: - self._graph_server = await self._graph_service.ensure() + self._graph_server = await GraphService(self.graph_address).ensure() async def _shutdown_servers(self) -> None: if self._graph_server is not None: @@ -96,6 +106,6 @@ async def revert(self) -> None: for edge in self._edges: try: - await self._graph_service.disconnect(*edge) + await GraphService(self.graph_address).disconnect(*edge) except (ConnectionRefusedError, BrokenPipeError, ConnectionResetError) as e: logger.warn(f"Could not remove edge {edge} from GraphServer: {e}") diff --git a/src/ezmsg/core/graphserver.py b/src/ezmsg/core/graphserver.py index 59cd929a..379f4978 100644 --- a/src/ezmsg/core/graphserver.py +++ b/src/ezmsg/core/graphserver.py @@ -81,11 +81,11 @@ async def api( await writer.drain() req = await reader.read(1) - - # Empty bytes object means EOF; Client disconnected - # This happens frequently when future clients are just pinging - # GraphServer to check if server is up + if not req: + # Empty bytes object means EOF; Client disconnected + # This happens frequently when future clients are just pinging + # GraphServer to check if server is up await close_stream_writer(writer) return @@ -392,7 +392,7 @@ async def dag(self, timeout: Optional[float] = None) -> DAG: writer.write(Command.DAG.value) await writer.drain() await asyncio.sleep(1.0) - await reader.readexactly(1) + await reader.read(1) dag_num_bytes = await read_int(reader) dag_bytes = await reader.readexactly(dag_num_bytes) dag: DAG = pickle.loads(dag_bytes) diff --git a/src/ezmsg/core/messagechannel.py b/src/ezmsg/core/messagechannel.py index a60ec5d0..69935e9d 100644 --- a/src/ezmsg/core/messagechannel.py +++ b/src/ezmsg/core/messagechannel.py @@ -4,7 +4,7 @@ import logging from uuid import UUID -from contextlib import contextmanager +from contextlib import contextmanager, suppress from .shm import SHMContext from .messagemarshal import MessageMarshal @@ -26,6 +26,9 @@ logger = logging.getLogger("ezmsg") +NotificationQueue = asyncio.Queue[typing.Tuple[UUID, int]] + + class CacheMiss(Exception): ... @@ -41,7 +44,7 @@ class _Channel: cache: typing.List[typing.Any] cache_id: typing.List[int | None] shm: SHMContext | None - sub_queues: typing.Dict[UUID, asyncio.Queue[typing.Tuple[UUID, int]]] + subs: typing.Dict[UUID, NotificationQueue] backpressure: Backpressure _graph_task: asyncio.Task[None] @@ -65,7 +68,7 @@ def __init__( self.cache_id = [None] * self.num_buffers self.cache = [None] * self.num_buffers self.backpressure = Backpressure(self.num_buffers) - self.sub_queues = dict() + self.subs = dict() self._graph_address = graph_address @classmethod @@ -101,7 +104,7 @@ async def create( writer.write(Command.SHM_ATTACH_FAILED.value) writer.write(uint64_to_bytes(os.getpid())) - result = await reader.readexactly(1) + result = await reader.read(1) if result != Command.COMPLETE.value: raise ValueError(f'failed to create channel {pub_id=}') @@ -120,6 +123,16 @@ async def create( return chan + def close(self) -> None: + self._graph_task.cancel() + self._pub_task.cancel() + + async def wait_closed(self) -> None: + with suppress(asyncio.CancelledError): + await self._graph_task + with suppress(asyncio.CancelledError): + await self._pub_task + async def _graph_connection( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter ) -> None: @@ -179,15 +192,14 @@ async def _publisher_connection(self, reader: asyncio.StreamReader) -> None: finally: await close_stream_writer(self._pub_writer) - # TODO: Remove this channel from CHANNELS ... ? maybe? logger.debug(f"disconnected: channel:{self.id} -> pub:{id}") def _notify_subs(self, msg_id: int) -> None: - for sub_id, queue in self.sub_queues.items(): + for sub_id, queue in self.subs.items(): self.backpressure.lease(sub_id, msg_id % self.num_buffers) queue.put_nowait((self.pub_id, msg_id)) - def put(self, msg_id: int, msg: typing.Any) -> None: + def put_local(self, msg_id: int, msg: typing.Any) -> None: """put an object into cache (should only be used by Publishers)""" buf_idx = msg_id % self.num_buffers self.cache_id[buf_idx] = msg_id @@ -225,30 +237,37 @@ def get(self, msg_id: int, sub_id: UUID) -> typing.Generator[typing.Any, None, N except (BrokenPipeError, ConnectionResetError): logger.debug(f"ack fail: channel:{self.id} -> pub:{self.pub_id}") - def subscribe(self, sub_id: UUID, sub_queue: asyncio.Queue[typing.Tuple[UUID, int]]) -> None: - self.sub_queues[sub_id] = sub_queue + def subscribe(self, sub_id: UUID, sub_queue: NotificationQueue) -> None: + self.subs[sub_id] = sub_queue def unsubscribe(self, sub_id: UUID) -> None: - queue = self.sub_queues.get(sub_id, None) - - if queue is None: - return - - del self.sub_queues[sub_id] + queue = self.subs[sub_id] for _ in range(queue.qsize()): - ch_id, msg_id = queue.get_nowait() - if ch_id == self.id: + pub_id, msg_id = queue.get_nowait() + if pub_id == self.pub_id: continue - queue.put_nowait((ch_id, msg_id)) + queue.put_nowait((pub_id, msg_id)) self.backpressure.free(sub_id) + del self.subs[sub_id] + def clear_cache(self): self.cache_id = [None] * self.num_buffers self.cache = [None] * self.num_buffers +def _ensure_address(address: AddressType | None) -> Address: + if address is None: + return Address.from_string(GRAPHSERVER_ADDR) + + elif not isinstance(address, Address): + return Address(*address) + + return address + + class _ChannelManager: _registry: typing.Dict[Address, typing.Dict[UUID, _Channel]] @@ -257,37 +276,58 @@ def __init__(self): default_address = Address.from_string(GRAPHSERVER_ADDR) self._registry = {default_address: dict()} - async def get(self, id: UUID, graph_address: AddressType | None = None) -> _Channel: - - if graph_address is None: - graph_address = Address.from_string(GRAPHSERVER_ADDR) - - elif not isinstance(graph_address, Address): - graph_address = Address(*graph_address) - + async def get( + self, + pub_id: UUID, + graph_address: AddressType | None = None, + create: bool = True + ) -> _Channel: + graph_address = _ensure_address(graph_address) channels = self._registry.get(graph_address, dict()) - channel = channels.get(id, None) - if channel is None: - channel = await _Channel.create(id, graph_address) - channels[id] = channel + channel = channels.get(pub_id, None) + if create and channel is None: + channel = await _Channel.create(pub_id, graph_address) + channels[pub_id] = channel self._registry[graph_address] = channels - + if channel is None: + raise ValueError("Channel does not exist") return channel - def unsubscribe_all(self, sub_id: UUID, graph_address: AddressType | None = None) -> None: - - if graph_address is None: - graph_address = Address.from_string(GRAPHSERVER_ADDR) - - elif not isinstance(graph_address, Address): - graph_address = Address(*graph_address) + async def subscribe( + self, + pub_id: UUID, + sub_id: UUID, + sub_queue: NotificationQueue, + graph_address: AddressType | None = None + ) -> _Channel: + channel = await self.get(pub_id, graph_address, create = True) + channel.subscribe(sub_id, sub_queue) + return channel + + async def unsubscribe_all( + self, + sub_id: UUID, + graph_address: AddressType | None = None + ) -> None: + graph_address = _ensure_address(graph_address) + channels = self._registry.get(graph_address, dict()) + for pub_id, channel in channels.items(): + if sub_id in channel.subs: + await self.unsubscribe(pub_id, sub_id, graph_address) - channels = self._registry.get(graph_address, None) - if channels is None: - return - - for channel in channels.values(): - channel.unsubscribe(sub_id) + async def unsubscribe( + self, + pub_id: UUID, + sub_id: UUID, + graph_address: AddressType | None = None + ) -> None: + graph_address = _ensure_address(graph_address) + channel = self._registry[graph_address][pub_id] + channel.unsubscribe(sub_id) + if len(channel.subs) == 0: + channel.close() + await channel.wait_closed() + logger.debug(f'closed channel {channel.id}: no subs') CHANNELS = _ChannelManager() \ No newline at end of file diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index a87efd97..681394f7 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -10,7 +10,7 @@ from .backpressure import Backpressure from .shm import SHMContext from .graphserver import GraphService -from .messagechannel import CHANNELS +from .messagechannel import CHANNELS, _Channel from .messagemarshal import MessageMarshal from .netprotocol import ( @@ -55,6 +55,7 @@ class Publisher: _connection_task: "asyncio.Task[None]" _channels: Dict[UUID, PubChannelInfo] _channel_tasks: Dict[UUID, "asyncio.Task[None]"] + _local_channel: _Channel _address: Address _backpressure: Backpressure _num_buffers: int @@ -100,7 +101,7 @@ async def create( writer.write(encode_str(topic)) address.to_stream(writer) - result = await reader.readexactly(1) + result = await reader.read(1) if result != Command.COMPLETE.value: logger.warning(f'Could not create publisher {topic=}') @@ -124,7 +125,7 @@ def on_done(_: asyncio.Future) -> None: pub._connection_task.add_done_callback(on_done) # Create the local Channel (it shouldn't already exist) - await CHANNELS.get(pub.id, graph_address) + pub._local_channel = await CHANNELS.get(pub.id, graph_address, create = True) return pub @@ -209,9 +210,7 @@ async def _graph_connection( async def _channel_connect( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter ) -> None: - """ Only messagechannel._Channel will connect here """ - - cmd = await reader.readexactly(1) + cmd = await reader.read(1) if len(cmd) == 0: return @@ -220,7 +219,7 @@ async def _channel_connect( channel_id_str = await read_str(reader) channel_id = UUID(channel_id_str) writer.write(encode_str(self._shm.name)) - shm_ok = await reader.readexactly(1) == Command.SHM_OK.value + shm_ok = await reader.read(1) == Command.SHM_OK.value pid = await read_int(reader) info = PubChannelInfo(channel_id, writer, self.id, pid, shm_ok) coro = self._handle_channel(info, reader) @@ -283,9 +282,8 @@ async def broadcast(self, obj: Any) -> None: # Get local channel and put variable there for local tx if not self._force_tcp: - channel = await CHANNELS.get(self.id, self._graph_address) - channel.put(self._msg_id, obj) - self._backpressure.lease(channel.id, buf_idx) + self._local_channel.put_local(self._msg_id, obj) + self._backpressure.lease(self._local_channel.id, buf_idx) if self._force_tcp or any(ch.pid != self.pid or not ch.shm_ok for ch in self._channels.values()): with MessageMarshal.serialize(self._msg_id, obj) as (total_size, header, buffers): diff --git a/src/ezmsg/core/subclient.py b/src/ezmsg/core/subclient.py index 960bccc2..fb299398 100644 --- a/src/ezmsg/core/subclient.py +++ b/src/ezmsg/core/subclient.py @@ -7,7 +7,7 @@ from copy import deepcopy from .graphserver import GraphService -from .messagechannel import CHANNELS +from .messagechannel import CHANNELS, NotificationQueue, _Channel from .netprotocol import ( AddressType, @@ -28,7 +28,12 @@ class Subscriber: _graph_address: AddressType | None _graph_task: "asyncio.Task[None]" _cur_pubs: typing.Set[UUID] - _incoming: "asyncio.Queue[typing.Tuple[UUID, int]]" + _incoming: NotificationQueue + + # This is an optimization to retain a local handle to channels + # so that dict lookup and wrapper contextmanager aren't in + # the hotpath - Griff + _channels: typing.Dict[UUID, _Channel] @classmethod async def create( @@ -41,7 +46,7 @@ async def create( writer.write(Command.SUBSCRIBE.value) writer.write(encode_str(topic)) - result = await reader.readexactly(1) + result = await reader.read(1) if result != Command.COMPLETE.value: logger.warning(f'Could not create subscriber {topic=}') @@ -65,6 +70,7 @@ def __init__(self, self._cur_pubs = set() self._incoming = asyncio.Queue() + self._channels = dict() def close(self) -> None: self._graph_task.cancel() @@ -79,20 +85,21 @@ async def _graph_connection( try: while True: cmd = await reader.read(1) + if not cmd: break elif cmd == Command.UPDATE.value: update = await read_str(reader) - pub_ids = set([UUID(id) for id in update.split(',')]) + pub_ids = set([UUID(id) for id in update.split(',')]) if update else set() for pub_id in set(pub_ids - self._cur_pubs): - channel = await CHANNELS.get(pub_id, self._graph_address) - channel.subscribe(self.id, self._incoming) + channel = await CHANNELS.subscribe(pub_id, self.id, self._incoming, self._graph_address) + self._channels[pub_id] = channel for pub_id in set(self._cur_pubs - pub_ids): - channel = await CHANNELS.get(pub_id, self._graph_address) - channel.unsubscribe(self.id) + channel = await CHANNELS.unsubscribe(pub_id, self.id, self._graph_address) + del self._channels[pub_id] writer.write(Command.COMPLETE.value) await writer.drain() @@ -106,7 +113,7 @@ async def _graph_connection( logger.debug(f"Subscriber {self.id} lost connection to graph server") finally: - CHANNELS.unsubscribe_all(self.id, self._graph_address) # good idea? + await CHANNELS.unsubscribe_all(self.id, self._graph_address) await close_stream_writer(writer) async def recv(self) -> typing.Any: @@ -117,10 +124,9 @@ async def recv(self) -> typing.Any: @asynccontextmanager async def recv_zero_copy(self) -> typing.AsyncGenerator[typing.Any, None]: - id, msg_id = await self._incoming.get() + pub_id, msg_id = await self._incoming.get() - channel = await CHANNELS.get(id, self._graph_address) - with channel.get(msg_id, self.id) as msg: + with self._channels[pub_id].get(msg_id, self.id) as msg: yield msg diff --git a/tests/test_new_messaging.py b/tests/test_new_messaging.py index 9dec127b..782187ec 100644 --- a/tests/test_new_messaging.py +++ b/tests/test_new_messaging.py @@ -48,18 +48,30 @@ async def host(host: str = '127.0.0.1'): graph_service = GraphService((host, PORT)) await graph_service.ensure() - test_pub = await Publisher.create(TOPIC, (host, PORT), host=host) - test_sub1 = await Subscriber.create(TOPIC, (host, PORT)) - test_sub2 = await Subscriber.create(TOPIC, (host, PORT)) + try: - await asyncio.sleep(1.0) + test_pub = await Publisher.create(TOPIC, (host, PORT), host=host) + test_sub1 = await Subscriber.create(TOPIC, (host, PORT)) + test_sub2 = await Subscriber.create(TOPIC, (host, PORT)) - pub_task = asyncio.Task(handle_pub(test_pub)) - sub_task_1 = asyncio.Task(handle_sub(test_sub1)) - sub_task_2 = asyncio.Task(handle_sub(test_sub2)) + await asyncio.sleep(1.0) + + pub_task = asyncio.Task(handle_pub(test_pub)) + sub_task_1 = asyncio.Task(handle_sub(test_sub1)) + sub_task_2 = asyncio.Task(handle_sub(test_sub2)) - try: await asyncio.wait([pub_task, sub_task_1, sub_task_2]) + + test_pub.close() + test_sub1.close() + test_sub2.close() + + for future in asyncio.as_completed([ + test_pub.wait_closed(), + test_sub1.wait_closed(), + test_sub2.wait_closed(), + ]): + await future finally: server.stop() @@ -76,10 +88,19 @@ async def attach_client(host: str = '127.0.0.1'): sub = await Subscriber.create(TOPIC, (host, PORT)) - while True: - async with sub.recv_zero_copy() as msg: - await asyncio.sleep(1.0) - print(msg) + try: + while True: + async with sub.recv_zero_copy() as msg: + await asyncio.sleep(1.0) + print(msg) + + except asyncio.CancelledError: + pass + + finally: + sub.close() + await sub.wait_closed() + print(f'Detached') if __name__ == '__main__': diff --git a/tests/test_reconstitute.py b/tests/test_reconstitute.py index 8b1e418f..f81c04bd 100644 --- a/tests/test_reconstitute.py +++ b/tests/test_reconstitute.py @@ -1,7 +1,11 @@ import asyncio from copy import deepcopy import timeit -import matplotlib.pyplot as plt + +try: + import matplotlib.pyplot as plt +except ImportError: + plt = None from ezmsg.core.graphserver import GraphServer, GraphService from ezmsg.util.messages.axisarray import AxisArray @@ -84,31 +88,32 @@ def main(): server.stop() # Plot results - plt.figure(figsize=(12, 8)) - for msg_size_bytes in msg_size_bytes_list: - fanouts = results[msg_size_bytes]['fanouts'] - plt.plot( - fanouts, - results[msg_size_bytes]['total_recon'], - label=f'Recon N times ({msg_size_bytes} bytes)', - marker='o', linestyle='--', alpha=0.7 - ) - plt.plot( - fanouts, - results[msg_size_bytes]['total_deepcopy'], - label=f'Recon 1 + deepcopy N-1 ({msg_size_bytes} bytes)', - marker='x', linestyle='-', alpha=0.7 - ) - plt.xscale('log', base=2) - plt.yscale('log') - plt.xlabel('Fanout (number of consumers)') - plt.ylabel('Total time (ns)') - plt.title('Fanout Tradeoff: Reconstitute N times vs Reconstitute+Deepcopy') - plt.legend(fontsize='small', ncol=2) - plt.grid(True, which='both', ls='--', alpha=0.5) - plt.tight_layout() - plt.savefig('fanout_tradeoff.png') - print('Plot saved as fanout_tradeoff.png') + if plt is not None: + plt.figure(figsize=(12, 8)) + for msg_size_bytes in msg_size_bytes_list: + fanouts = results[msg_size_bytes]['fanouts'] + plt.plot( + fanouts, + results[msg_size_bytes]['total_recon'], + label=f'Recon N times ({msg_size_bytes} bytes)', + marker='o', linestyle='--', alpha=0.7 + ) + plt.plot( + fanouts, + results[msg_size_bytes]['total_deepcopy'], + label=f'Recon 1 + deepcopy N-1 ({msg_size_bytes} bytes)', + marker='x', linestyle='-', alpha=0.7 + ) + plt.xscale('log', base=2) + plt.yscale('log') + plt.xlabel('Fanout (number of consumers)') + plt.ylabel('Total time (ns)') + plt.title('Fanout Tradeoff: Reconstitute N times vs Reconstitute+Deepcopy') + plt.legend(fontsize='small', ncol=2) + plt.grid(True, which='both', ls='--', alpha=0.5) + plt.tight_layout() + plt.savefig('fanout_tradeoff.png') + print('Plot saved as fanout_tradeoff.png') if __name__ == '__main__': main() \ No newline at end of file From 35f8a10d0023b017d4ae6856a66cd4841e4fd36c Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Thu, 21 Aug 2025 19:13:39 -0400 Subject: [PATCH 28/88] lots of bugfixes --- src/ezmsg/core/graphcontext.py | 9 +- src/ezmsg/core/graphserver.py | 5 +- src/ezmsg/core/messagechannel.py | 148 ++++++++++++++++--------------- src/ezmsg/core/pubclient.py | 40 +++++---- src/ezmsg/core/shm.py | 82 ++++++++--------- src/ezmsg/core/subclient.py | 18 ++-- tests/test_graph.py | 32 ++++++- 7 files changed, 196 insertions(+), 138 deletions(-) diff --git a/src/ezmsg/core/graphcontext.py b/src/ezmsg/core/graphcontext.py index 4dc91e2d..38af8c22 100644 --- a/src/ezmsg/core/graphcontext.py +++ b/src/ezmsg/core/graphcontext.py @@ -97,13 +97,16 @@ async def __aexit__( return False async def revert(self) -> None: + logger.info('revert graphcontext: close clients') for client in self._clients: client.close() + await client.wait_closed() - wait = [c.wait_closed() for c in self._clients] - for future in asyncio.as_completed(wait): - await future + # wait = [c.wait_closed() for c in self._clients] + # for future in asyncio.as_completed(wait): + # await future + logger.info('revert graphcontext: disconnect edges') for edge in self._edges: try: await GraphService(self.graph_address).disconnect(*edge) diff --git a/src/ezmsg/core/graphserver.py b/src/ezmsg/core/graphserver.py index 379f4978..6f3b424e 100644 --- a/src/ezmsg/core/graphserver.py +++ b/src/ezmsg/core/graphserver.py @@ -49,7 +49,6 @@ def __init__(self) -> None: self.graph = DAG() self.clients = dict() self._client_tasks = dict() - self.shms = dict() async def setup(self) -> None: @@ -189,6 +188,9 @@ async def api( writer.write(Command.COMPLETE.value) writer.write(encode_str(str(client_id))) + + # Wait until pub's channel server is up before + # notifying subs for sub in self._downstream_subs(info.topic): await self._notify_subscriber(sub) @@ -296,7 +298,6 @@ async def _notify_subscriber(self, sub: SubscriberInfo) -> None: writer.write(Command.UPDATE.value) writer.write(encode_str(notify_str)) - except (ConnectionResetError, BrokenPipeError) as e: logger.debug(f"Failed to update Subscriber {sub.id}: {e}") diff --git a/src/ezmsg/core/messagechannel.py b/src/ezmsg/core/messagechannel.py index 69935e9d..ecf0ee40 100644 --- a/src/ezmsg/core/messagechannel.py +++ b/src/ezmsg/core/messagechannel.py @@ -44,7 +44,7 @@ class _Channel: cache: typing.List[typing.Any] cache_id: typing.List[int | None] shm: SHMContext | None - subs: typing.Dict[UUID, NotificationQueue] + clients: typing.Dict[UUID, NotificationQueue | None] backpressure: Backpressure _graph_task: asyncio.Task[None] @@ -68,7 +68,7 @@ def __init__( self.cache_id = [None] * self.num_buffers self.cache = [None] * self.num_buffers self.backpressure = Backpressure(self.num_buffers) - self.subs = dict() + self.clients = dict() self._graph_address = graph_address @classmethod @@ -77,6 +77,7 @@ async def create( pub_id: UUID, graph_address: AddressType, ) -> "_Channel": + logger.info(f'attempting to create channel {pub_id=}') graph_service = GraphService(graph_address) graph_reader, graph_writer = await graph_service.open_connection() @@ -91,7 +92,6 @@ async def create( pub_address = await Address.from_stream(graph_reader) reader, writer = await asyncio.open_connection(*pub_address) - writer.write(Command.CHANNEL.value) writer.write(encode_str(id_str)) @@ -113,14 +113,18 @@ async def create( chan = cls(UUID(id_str), pub_id, num_buffers, shm) chan._graph_task = asyncio.create_task( - chan._graph_connection(graph_reader, graph_writer) + chan._graph_connection(graph_reader, graph_writer), + name = f'chan-{chan.id}: _graph_connection' ) chan._pub_writer = writer chan._pub_task = asyncio.create_task( - chan._publisher_connection(reader) + chan._publisher_connection(reader), + name = f'chan-{chan.id}: _publisher_connection' ) + logger.info(f'created channel {pub_id=}') + return chan def close(self) -> None: @@ -185,7 +189,7 @@ async def _publisher_connection(self, reader: asyncio.StreamReader) -> None: self.cache[buf_idx] = obj self.cache_id[buf_idx] = msg_id - self._notify_subs(msg_id) + self._notify_clients(msg_id) except (ConnectionResetError, BrokenPipeError): logger.debug(f"connection fail: channel:{self.id} - pub:{self.pub_id}") @@ -194,9 +198,10 @@ async def _publisher_connection(self, reader: asyncio.StreamReader) -> None: await close_stream_writer(self._pub_writer) logger.debug(f"disconnected: channel:{self.id} -> pub:{id}") - def _notify_subs(self, msg_id: int) -> None: - for sub_id, queue in self.subs.items(): - self.backpressure.lease(sub_id, msg_id % self.num_buffers) + def _notify_clients(self, msg_id: int) -> None: + for client_id, queue in self.clients.items(): + if queue is None: continue + self.backpressure.lease(client_id, msg_id % self.num_buffers) queue.put_nowait((self.pub_id, msg_id)) def put_local(self, msg_id: int, msg: typing.Any) -> None: @@ -204,10 +209,10 @@ def put_local(self, msg_id: int, msg: typing.Any) -> None: buf_idx = msg_id % self.num_buffers self.cache_id[buf_idx] = msg_id self.cache[buf_idx] = msg - self._notify_subs(msg_id) + self._notify_clients(msg_id) @contextmanager - def get(self, msg_id: int, sub_id: UUID) -> typing.Generator[typing.Any, None, None]: + def get(self, msg_id: int, client_id: UUID) -> typing.Generator[typing.Any, None, None]: """get object from cache; if not in cache and shm provided -- get from shm""" buf_idx = msg_id % self.num_buffers @@ -223,13 +228,27 @@ def get(self, msg_id: int, sub_id: UUID) -> typing.Generator[typing.Any, None, N raise CacheMiss with MessageMarshal.obj_from_mem(mem) as obj: - # Could deepcopy and put in cache here, but - # profiling indicates its faster to repeatedly - # reconstruct from memory for fanout <= 4 subs + # Could deepcopy and put in cache here... + # Profiling indicates its faster to repeatedly + # reconstruct <=512kB msgs from memory for up to + # about 4 subs -- deepcopy is about 4x slower # which I suspect will be majority of cases + # With more than 4 subs, one deepcopy may be faster + # for <=512kB msgs, but becomes remarkably time + # consuming for >= 512kB msgs. + + # TODO: Implement a tuning method that runs on + # first ezmsg run that determines this tradeoff + # and decides fastest behavior; then decide + # to cache or not cache depending on this profiling + # for this particular machine + + # self.shm.buf_size # Buffer size + # len(self.clients) # Channel fanout + yield obj - self.backpressure.free(sub_id, buf_idx) + self.backpressure.free(client_id, buf_idx) if self.backpressure.buffers[buf_idx].is_empty: try: ack = Command.RX_ACK.value + uint64_to_bytes(msg_id) @@ -237,21 +256,22 @@ def get(self, msg_id: int, sub_id: UUID) -> typing.Generator[typing.Any, None, N except (BrokenPipeError, ConnectionResetError): logger.debug(f"ack fail: channel:{self.id} -> pub:{self.pub_id}") - def subscribe(self, sub_id: UUID, sub_queue: NotificationQueue) -> None: - self.subs[sub_id] = sub_queue + def register_client(self, client_id: UUID, queue: NotificationQueue | None = None) -> None: + self.clients[client_id] = queue - def unsubscribe(self, sub_id: UUID) -> None: - queue = self.subs[sub_id] + def unregister_client(self, client_id: UUID) -> None: + queue = self.clients[client_id] - for _ in range(queue.qsize()): - pub_id, msg_id = queue.get_nowait() - if pub_id == self.pub_id: - continue - queue.put_nowait((pub_id, msg_id)) + if queue is not None: + for _ in range(queue.qsize()): + pub_id, msg_id = queue.get_nowait() + if pub_id == self.pub_id: + continue + queue.put_nowait((pub_id, msg_id)) - self.backpressure.free(sub_id) + self.backpressure.free(client_id) - del self.subs[sub_id] + del self.clients[client_id] def clear_cache(self): self.cache_id = [None] * self.num_buffers @@ -270,64 +290,52 @@ def _ensure_address(address: AddressType | None) -> Address: class _ChannelManager: + _lock: asyncio.Lock _registry: typing.Dict[Address, typing.Dict[UUID, _Channel]] def __init__(self): default_address = Address.from_string(GRAPHSERVER_ADDR) self._registry = {default_address: dict()} + self._lock = asyncio.Lock() - async def get( - self, - pub_id: UUID, - graph_address: AddressType | None = None, - create: bool = True - ) -> _Channel: - graph_address = _ensure_address(graph_address) - channels = self._registry.get(graph_address, dict()) - channel = channels.get(pub_id, None) - if create and channel is None: - channel = await _Channel.create(pub_id, graph_address) - channels[pub_id] = channel - self._registry[graph_address] = channels - if channel is None: - raise ValueError("Channel does not exist") - return channel - - async def subscribe( + async def register( self, pub_id: UUID, - sub_id: UUID, - sub_queue: NotificationQueue, + client_id: UUID, + queue: NotificationQueue | None = None, graph_address: AddressType | None = None ) -> _Channel: - channel = await self.get(pub_id, graph_address, create = True) - channel.subscribe(sub_id, sub_queue) - return channel - - async def unsubscribe_all( - self, - sub_id: UUID, - graph_address: AddressType | None = None - ) -> None: - graph_address = _ensure_address(graph_address) - channels = self._registry.get(graph_address, dict()) - for pub_id, channel in channels.items(): - if sub_id in channel.subs: - await self.unsubscribe(pub_id, sub_id, graph_address) - - async def unsubscribe( + async with self._lock: + logger.info(f'register ch {pub_id=} {client_id=}') + graph_address = _ensure_address(graph_address) + channels = self._registry.get(graph_address, dict()) + channel = channels.get(pub_id, None) + if channel is None: + channel = await _Channel.create(pub_id, graph_address) + channels[pub_id] = channel + self._registry[graph_address] = channels + channel.register_client(client_id, queue) + logger.info(f'ch {pub_id=} {client_id=} reg DONE') + return channel + + async def unregister( self, pub_id: UUID, - sub_id: UUID, + client_id: UUID, graph_address: AddressType | None = None ) -> None: - graph_address = _ensure_address(graph_address) - channel = self._registry[graph_address][pub_id] - channel.unsubscribe(sub_id) - if len(channel.subs) == 0: - channel.close() - await channel.wait_closed() - logger.debug(f'closed channel {channel.id}: no subs') + async with self._lock: + logger.info(f'unregister ch {pub_id=} {client_id=} unreg') + graph_address = _ensure_address(graph_address) + registry = self._registry[graph_address] + channel = registry[pub_id] + channel.unregister_client(client_id) + if len(channel.clients) == 0: + channel.close() + await channel.wait_closed() + del registry[pub_id] + logger.info(f'closed channel {pub_id}: no clients') + logger.info(f'ch {pub_id=} {client_id=} unreg DONE') CHANNELS = _ChannelManager() \ No newline at end of file diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index 681394f7..5bc651af 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -83,6 +83,7 @@ async def create( **kwargs, ) -> "Publisher": # We have to fill in some parts of this class using async + logger.info(f'attempting to create pub {topic=}') pub = cls(topic, graph_address, num_buffers, **kwargs) graph_service = GraphService(graph_address) @@ -107,27 +108,33 @@ async def create( pub.id = UUID(await read_str(reader)) - pub._graph_task = asyncio.create_task(pub._graph_connection(reader, writer)) - - async def serve() -> None: - try: - await server.serve_forever() - except asyncio.CancelledError: - logger.debug("{pub.log_name} cancelled") - finally: - await close_server(server) - - pub._connection_task = asyncio.create_task(serve(), name=pub.log_name) + pub._graph_task = asyncio.create_task( + pub._graph_connection(reader, writer), + name = f'pub-{pub.id}: _graph_connection' + ) - def on_done(_: asyncio.Future) -> None: - logger.debug("{pub.log_name} done") + pub._connection_task = asyncio.create_task( + pub._serve_channels(server), + name = f'pub-{pub.id}: pub.log_name' + ) - pub._connection_task.add_done_callback(on_done) + pub._local_channel = await CHANNELS.register( + pub.id, pub.id, None, pub._graph_address + ) - # Create the local Channel (it shouldn't already exist) - pub._local_channel = await CHANNELS.get(pub.id, graph_address, create = True) + logger.info(f'created pub {topic=} {pub.id=}') return pub + + async def _serve_channels(self, server: asyncio.Server) -> None: + try: + await server.serve_forever() + except asyncio.CancelledError: + logger.debug(f"{self.log_name} cancelled") + finally: + await close_server(server) + await CHANNELS.unregister(self.id, self.id, self._graph_address) + logger.debug(f"{self.log_name} done") def __init__( self, @@ -151,7 +158,6 @@ def __init__( self._backpressure = Backpressure(num_buffers) self._force_tcp = force_tcp self._last_backpressure_event = -1 - self._graph_address = graph_address @property diff --git a/src/ezmsg/core/shm.py b/src/ezmsg/core/shm.py index b7951973..3bee0c1d 100644 --- a/src/ezmsg/core/shm.py +++ b/src/ezmsg/core/shm.py @@ -57,13 +57,12 @@ class SHMContext: exposes memoryview objects for reading and writing """ - _shm: SharedMemory - _data_block_segs: typing.List[slice] - num_buffers: int buf_size: int - monitor: asyncio.Future + _shm: SharedMemory + _data_block_segs: typing.List[slice] + _graph_task: asyncio.Task[None] def __init__(self, name: str) -> None: with _untracked_shm(): @@ -88,29 +87,30 @@ def attach( cls, shm_name: str, reader: asyncio.StreamReader, writer: asyncio.StreamWriter ) -> "SHMContext": context = cls(shm_name) - - async def monitor() -> None: - try: - await reader.read() - logger.debug("Read from SHMContext monitor reader") - except asyncio.CancelledError: - pass - finally: - await close_stream_writer(writer) - - def close(_: asyncio.Future) -> None: - context.close() - - context.monitor = asyncio.create_task(monitor(), name=f"{shm_name}_monitor") - context.monitor.add_done_callback(close) + context._graph_task = asyncio.create_task( + context._graph_connection(reader, writer), + name=f"{context.name}_monitor" + ) return context + + async def _graph_connection( + self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter + ) -> None: + try: + await reader.read() + logger.debug(f"SHMContext {self.name} GraphServer hangup") + except (ConnectionResetError, BrokenPipeError) as e: + logger.debug(f"SHMContext {self.name} GraphServer {type(e)}") + finally: + await close_stream_writer(writer) + self._shm.close() @contextmanager def buffer( self, idx: int, readonly: bool = False ) -> typing.Generator[memoryview, None, None]: if self._shm.buf is None: - raise BufferError(f"cannot access {self._shm.name}: server disconnected") + raise BufferError(f"cannot access {self.name}: server disconnected") with self._shm.buf[self._data_block_segs[idx]] as mem: if readonly: @@ -121,22 +121,11 @@ def buffer( yield mem def close(self) -> None: - asyncio.create_task(self.close_shm(), name=f"Close {self._shm.name}") - self.monitor.cancel() - - async def close_shm(self) -> None: - while True: - try: - self._shm.close() - logger.debug("Closed SHM segment.") - return - except BufferError: - logger.debug("BufferError caught... Sleeping.") - await asyncio.sleep(0.1) + self._graph_task.cancel() async def wait_closed(self) -> None: with suppress(asyncio.CancelledError): - await self.monitor + await self._graph_task @property def name(self) -> str: @@ -145,6 +134,17 @@ def name(self) -> str: @property def size(self) -> int: return self.buf_size - 16 # 16 byte header + + # This seems like it shouldn't be a thing. + # async def close_shm(self) -> None: + # while True: + # try: + # self._shm.close() + # logger.debug("Closed SHM segment.") + # return + # except BufferError: + # logger.debug("BufferError caught... Sleeping.") + # await asyncio.sleep(0.1) @dataclass @@ -166,16 +166,18 @@ def create( def lease( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter ) -> "asyncio.Task[None]": - async def _wait_for_eof() -> None: - try: - await reader.read() - finally: - await close_stream_writer(writer) - - lease = asyncio.create_task(_wait_for_eof()) + lease = asyncio.create_task(self._wait_for_eof(reader, writer)) lease.add_done_callback(self._release) self.leases.add(lease) return lease + + async def _wait_for_eof( + self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter + ) -> None: + try: + await reader.read() + finally: + await close_stream_writer(writer) def _release(self, task: "asyncio.Task[None]"): self.leases.discard(task) diff --git a/src/ezmsg/core/subclient.py b/src/ezmsg/core/subclient.py index fb299398..acb791af 100644 --- a/src/ezmsg/core/subclient.py +++ b/src/ezmsg/core/subclient.py @@ -26,7 +26,7 @@ class Subscriber: topic: str _graph_address: AddressType | None - _graph_task: "asyncio.Task[None]" + _graph_task: asyncio.Task[None] _cur_pubs: typing.Set[UUID] _incoming: NotificationQueue @@ -42,6 +42,7 @@ async def create( graph_address: AddressType | None, **kwargs ) -> "Subscriber": + logger.info(f'attempting to create sub {topic=}') reader, writer = await GraphService(graph_address).open_connection() writer.write(Command.SUBSCRIBE.value) writer.write(encode_str(topic)) @@ -53,7 +54,12 @@ async def create( id_str = await read_str(reader) sub = cls(UUID(id_str), topic, graph_address, **kwargs) - sub._graph_task = asyncio.create_task(sub._graph_connection(reader, writer)) + sub._graph_task = asyncio.create_task( + sub._graph_connection(reader, writer), + name = f'sub-{sub.id}: _graph_connection' + ) + + logger.info(f'created sub {topic=} {sub.id=}') return sub @@ -92,13 +98,14 @@ async def _graph_connection( elif cmd == Command.UPDATE.value: update = await read_str(reader) pub_ids = set([UUID(id) for id in update.split(',')]) if update else set() + logger.info(f'{pub_ids=}') for pub_id in set(pub_ids - self._cur_pubs): - channel = await CHANNELS.subscribe(pub_id, self.id, self._incoming, self._graph_address) + channel = await CHANNELS.register(pub_id, self.id, self._incoming, self._graph_address) self._channels[pub_id] = channel for pub_id in set(self._cur_pubs - pub_ids): - channel = await CHANNELS.unsubscribe(pub_id, self.id, self._graph_address) + await CHANNELS.unregister(pub_id, self.id, self._graph_address) del self._channels[pub_id] writer.write(Command.COMPLETE.value) @@ -113,7 +120,8 @@ async def _graph_connection( logger.debug(f"Subscriber {self.id} lost connection to graph server") finally: - await CHANNELS.unsubscribe_all(self.id, self._graph_address) + for pub_id in self._channels: + await CHANNELS.unregister(pub_id, self.id, self._graph_address) await close_stream_writer(writer) async def recv(self) -> typing.Any: diff --git a/tests/test_graph.py b/tests/test_graph.py index b7b414ac..c3bb3839 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -16,7 +16,33 @@ ("d", "e"), ] -simple_graph_2 = [("w", "x"), ("w", "y"), ("x", "z"), ("y", "z")] +simple_graph_2 = [ + ("w", "x"), + ("w", "y"), + ("x", "z"), + ("y", "z") +] + +@pytest.mark.asyncio +async def test_pub_first(): + async with GraphContext() as context: + await context.publisher("a") + await asyncio.sleep(0.1) + await context.connect("a", "b") + await asyncio.sleep(0.1) + await context.publisher("b") + await asyncio.sleep(0.1) + + +@pytest.mark.asyncio +async def test_sub_first(): + async with GraphContext() as context: + await context.subscriber("b") + await asyncio.sleep(0.1) + await context.connect("a", "b") + await asyncio.sleep(0.1) + await context.publisher("a") + await asyncio.sleep(0.1) @pytest.mark.asyncio @@ -28,9 +54,13 @@ async def test_graph(): for edge in simple_graph_1: await context.connect(*edge) + await asyncio.sleep(0.1) + await context.subscriber("c") await context.publisher("b") + await asyncio.sleep(0.1) + for edge in simple_graph_2: await context.connect(*edge) From f4bb029de45e8fedf3d882779df0d7b81e26ba12 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Fri, 22 Aug 2025 12:16:59 -0400 Subject: [PATCH 29/88] painful bugfixes --- src/ezmsg/core/graphcontext.py | 9 ++--- src/ezmsg/core/graphserver.py | 62 +++++++++++++++++++------------- src/ezmsg/core/messagechannel.py | 30 ++++++++++------ src/ezmsg/core/pubclient.py | 40 +++++++++++---------- src/ezmsg/core/server.py | 13 ++++--- src/ezmsg/core/shm.py | 2 +- src/ezmsg/core/subclient.py | 59 ++++++++++++++++++++++-------- tests/test_graph.py | 15 +++++++- 8 files changed, 147 insertions(+), 83 deletions(-) diff --git a/src/ezmsg/core/graphcontext.py b/src/ezmsg/core/graphcontext.py index 38af8c22..4dc91e2d 100644 --- a/src/ezmsg/core/graphcontext.py +++ b/src/ezmsg/core/graphcontext.py @@ -97,16 +97,13 @@ async def __aexit__( return False async def revert(self) -> None: - logger.info('revert graphcontext: close clients') for client in self._clients: client.close() - await client.wait_closed() - # wait = [c.wait_closed() for c in self._clients] - # for future in asyncio.as_completed(wait): - # await future + wait = [c.wait_closed() for c in self._clients] + for future in asyncio.as_completed(wait): + await future - logger.info('revert graphcontext: disconnect edges') for edge in self._edges: try: await GraphService(self.graph_address).disconnect(*edge) diff --git a/src/ezmsg/core/graphserver.py b/src/ezmsg/core/graphserver.py index 6f3b424e..5bb5a519 100644 --- a/src/ezmsg/core/graphserver.py +++ b/src/ezmsg/core/graphserver.py @@ -117,6 +117,14 @@ async def api( writer.write(Command.COMPLETE.value) writer.write(encode_str(shm_info.shm.name)) + # NOTE: SHMContexts are like GraphClients in that their + # lifetime is bound to the lifetime of this connection + # but rather than the early return that we have with + # GraphClients, we just await the lease task. + # This may be a more readable pattern? + # NOTE: With the current shutdown pattern, we cancel + # the lease task then await it. When we cancel the lease + # task this resolves then it gets awaited in shutdown. with suppress(asyncio.CancelledError): await shm_info.lease(reader, writer) @@ -129,7 +137,7 @@ async def api( try: pub_info = self.clients[pub_id] if isinstance(pub_info, PublisherInfo): - # assemble an address the channel should be able to resolve + # Advertise an address the channel should be able to resolve port = pub_info.address.port iface = pub_info.writer.transport.get_extra_info("peername")[0] pub_addr = Address(iface, port) @@ -139,21 +147,26 @@ async def api( except KeyError: logger.warning(f"Connecting channel requested non-existent publisher {pub_id=}") - if pub_addr is not None: - writer.write(Command.COMPLETE.value) - writer.write(encode_str(str(channel_id))) - pub_addr.to_stream(writer) - - info = ChannelInfo(channel_id, writer, pub_id) - self.clients[channel_id] = info - self._client_tasks[channel_id] = asyncio.create_task( - self._handle_client(channel_id, reader, writer) - ) - else: - # Error, drop connection + if pub_addr is None: + # FIXME: Channel should not be created; but it feels like we should + # have a better communication protocol to tell the channel what the + # error was and deliver a better experience from the client side. + # for now, drop connection await close_stream_writer(writer) + return - # Created a client; we must return early to avoid closing writer + writer.write(Command.COMPLETE.value) + writer.write(encode_str(str(channel_id))) + pub_addr.to_stream(writer) + + info = ChannelInfo(channel_id, writer, pub_id) + self.clients[channel_id] = info + self._client_tasks[channel_id] = asyncio.create_task( + self._handle_client(channel_id, reader, writer) + ) + + # NOTE: Created a client, must return early + # to avoid closing writer return else: @@ -166,6 +179,8 @@ async def api( client_id = uuid1() topic = await read_str(reader) + writer.write(encode_str(str(client_id))) + if req == Command.SUBSCRIBE.value: info = SubscriberInfo(client_id, writer, topic) self.clients[client_id] = info @@ -173,9 +188,6 @@ async def api( self._handle_client(client_id, reader, writer) ) - writer.write(Command.COMPLETE.value) - writer.write(encode_str(str(client_id))) - await self._notify_subscriber(info) elif req == Command.PUBLISH.value: @@ -185,17 +197,14 @@ async def api( self._client_tasks[client_id] = asyncio.create_task( self._handle_client(client_id, reader, writer) ) - - writer.write(Command.COMPLETE.value) - writer.write(encode_str(str(client_id))) - # Wait until pub's channel server is up before - # notifying subs - for sub in self._downstream_subs(info.topic): await self._notify_subscriber(sub) - # Created a client, must return early to avoid closing writer + writer.write(Command.COMPLETE.value) + + # NOTE: Created a client, must return early + # to avoid closing writer return elif req in [Command.CONNECT.value, Command.DISCONNECT.value]: @@ -250,6 +259,11 @@ async def api( except (ConnectionResetError, BrokenPipeError): logger.debug("GraphServer connection fail mid-command") + # NOTE: This prevents code repetition for many graph server commands, but + # when we create GraphClients, their lifecycle is bound to the lifecycle of + # this connection. We do NOT want to close the stream writer if we have + # created a GraphClient, which requires an early return. Perhaps a different + # communication protocol could resolve this await close_stream_writer(writer) async def _handle_client( diff --git a/src/ezmsg/core/messagechannel.py b/src/ezmsg/core/messagechannel.py index ecf0ee40..557bb05f 100644 --- a/src/ezmsg/core/messagechannel.py +++ b/src/ezmsg/core/messagechannel.py @@ -77,7 +77,6 @@ async def create( pub_id: UUID, graph_address: AddressType, ) -> "_Channel": - logger.info(f'attempting to create channel {pub_id=}') graph_service = GraphService(graph_address) graph_reader, graph_writer = await graph_service.open_connection() @@ -86,6 +85,9 @@ async def create( response = await graph_reader.read(1) if response != Command.COMPLETE.value: + # FIXME: This will happen if the channel requested connection + # to a non-existent (or non-publisher) UUID. Ideally GraphServer + # would tell us what happened rather than drop connection raise ValueError(f'failed to create channel {pub_id=}') id_str = await read_str(graph_reader) @@ -106,6 +108,8 @@ async def create( result = await reader.read(1) if result != Command.COMPLETE.value: + # NOTE: The only reason this would happen is if the + # publisher's writer is closed due to a crash or shutdown raise ValueError(f'failed to create channel {pub_id=}') num_buffers = await read_int(reader) @@ -123,15 +127,19 @@ async def create( name = f'chan-{chan.id}: _publisher_connection' ) - logger.info(f'created channel {pub_id=}') + logger.debug(f'created channel {chan.id=} {pub_id=}') return chan def close(self) -> None: + if self.shm is not None: + self.shm.close() self._graph_task.cancel() self._pub_task.cancel() async def wait_closed(self) -> None: + if self.shm is not None: + await self.shm.wait_closed() with suppress(asyncio.CancelledError): await self._graph_task with suppress(asyncio.CancelledError): @@ -173,6 +181,7 @@ async def _publisher_connection(self, reader: asyncio.StreamReader) -> None: if self.shm is not None and self.shm.name != shm_name: self.shm.close() + await self.shm.wait_closed() try: self.shm = await GraphService(self._graph_address).attach_shm(shm_name) except ValueError: @@ -241,11 +250,12 @@ def get(self, msg_id: int, client_id: UUID) -> typing.Generator[typing.Any, None # first ezmsg run that determines this tradeoff # and decides fastest behavior; then decide # to cache or not cache depending on this profiling - # for this particular machine - + # for this particular machine. + # NOTE: We have information about the message size and + # the current fanout at time of RX, so we could + # intelligently copy here! # self.shm.buf_size # Buffer size # len(self.clients) # Channel fanout - yield obj self.backpressure.free(client_id, buf_idx) @@ -254,7 +264,7 @@ def get(self, msg_id: int, client_id: UUID) -> typing.Generator[typing.Any, None ack = Command.RX_ACK.value + uint64_to_bytes(msg_id) self._pub_writer.write(ack) except (BrokenPipeError, ConnectionResetError): - logger.debug(f"ack fail: channel:{self.id} -> pub:{self.pub_id}") + logger.info(f"ack fail: channel:{self.id} -> pub:{self.pub_id}") def register_client(self, client_id: UUID, queue: NotificationQueue | None = None) -> None: self.clients[client_id] = queue @@ -306,7 +316,7 @@ async def register( graph_address: AddressType | None = None ) -> _Channel: async with self._lock: - logger.info(f'register ch {pub_id=} {client_id=}') + logger.debug(f'ch {pub_id=} {client_id=} reg DONE') graph_address = _ensure_address(graph_address) channels = self._registry.get(graph_address, dict()) channel = channels.get(pub_id, None) @@ -315,7 +325,6 @@ async def register( channels[pub_id] = channel self._registry[graph_address] = channels channel.register_client(client_id, queue) - logger.info(f'ch {pub_id=} {client_id=} reg DONE') return channel async def unregister( @@ -325,7 +334,7 @@ async def unregister( graph_address: AddressType | None = None ) -> None: async with self._lock: - logger.info(f'unregister ch {pub_id=} {client_id=} unreg') + logger.debug(f'ch {pub_id=} {client_id=} unreg') graph_address = _ensure_address(graph_address) registry = self._registry[graph_address] channel = registry[pub_id] @@ -334,8 +343,7 @@ async def unregister( channel.close() await channel.wait_closed() del registry[pub_id] - logger.info(f'closed channel {pub_id}: no clients') - logger.info(f'ch {pub_id=} {client_id=} unreg DONE') + logger.debug(f'closed channel {pub_id}: no clients') CHANNELS = _ChannelManager() \ No newline at end of file diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index 5bc651af..5f2852e3 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -82,47 +82,46 @@ async def create( num_buffers: int = 32, **kwargs, ) -> "Publisher": - # We have to fill in some parts of this class using async - logger.info(f'attempting to create pub {topic=}') - pub = cls(topic, graph_address, num_buffers, **kwargs) - graph_service = GraphService(graph_address) reader, writer = await graph_service.open_connection() - pub._shm = await graph_service.create_shm(num_buffers, buf_size) + shm = await graph_service.create_shm(num_buffers, buf_size) start_port = int( os.getenv(PUBLISHER_START_PORT_ENV, PUBLISHER_START_PORT_DEFAULT) ) + sock = create_socket(host, port, start_port=start_port) - address = Address(*sock.getsockname()) - server = await asyncio.start_server(pub._channel_connect, sock=sock) - writer.write(Command.PUBLISH.value) writer.write(encode_str(topic)) - address.to_stream(writer) - result = await reader.read(1) + pub_id = UUID(await read_str(reader)) + pub = cls(pub_id, topic, shm, graph_address, num_buffers, **kwargs) + + server = await asyncio.start_server(pub._channel_connect, sock=sock) + pub._connection_task = asyncio.create_task( + pub._serve_channels(server), + name = f'pub-{pub.id}: {pub.topic}' + ) + + # Notify GraphServer that our server is up + channel_server_address = Address(*sock.getsockname()) + channel_server_address.to_stream(writer) + result = await reader.read(1) # channels connect if result != Command.COMPLETE.value: logger.warning(f'Could not create publisher {topic=}') - pub.id = UUID(await read_str(reader)) - + # Pass off graph connection keep-alive to publisher task pub._graph_task = asyncio.create_task( pub._graph_connection(reader, writer), name = f'pub-{pub.id}: _graph_connection' ) - pub._connection_task = asyncio.create_task( - pub._serve_channels(server), - name = f'pub-{pub.id}: pub.log_name' - ) - pub._local_channel = await CHANNELS.register( pub.id, pub.id, None, pub._graph_address ) - logger.info(f'created pub {topic=} {pub.id=}') + logger.debug(f'created pub {pub.id=} {topic=}') return pub @@ -138,16 +137,19 @@ async def _serve_channels(self, server: asyncio.Server) -> None: def __init__( self, + id: UUID, topic: str, + shm: SHMContext, graph_address: AddressType | None = None, num_buffers: int = 32, start_paused: bool = False, force_tcp: bool = False, ) -> None: """DO NOT USE this constructor to make a Publisher; use `create` instead""" + self.id = id self.pid = os.getpid() self.topic = topic - + self._shm = shm self._msg_id = 0 self._channels = dict() self._channel_tasks = dict() diff --git a/src/ezmsg/core/server.py b/src/ezmsg/core/server.py index c1c1461c..36fd03c6 100644 --- a/src/ezmsg/core/server.py +++ b/src/ezmsg/core/server.py @@ -3,15 +3,14 @@ import logging import socket import typing +import threading from contextlib import suppress -from threading import Thread, Event from .netprotocol import ( Address, AddressType, DEFAULT_HOST, - close_server, close_stream_writer, create_socket, SERVER_PORT_START_ENV, @@ -21,19 +20,19 @@ logger = logging.getLogger("ezmsg") -class ThreadedAsyncServer(Thread): +class ThreadedAsyncServer(threading.Thread): """An asyncio server that runs in a dedicated loop in a separate thread""" - _server_up: Event - _shutdown: Event + _server_up: threading.Event + _shutdown: threading.Event _sock: socket.socket _loop: asyncio.AbstractEventLoop def __init__(self) -> None: super().__init__(daemon=True) - self._server_up = Event() - self._shutdown = Event() + self._server_up = threading.Event() + self._shutdown = threading.Event() @property def address(self) -> Address: diff --git a/src/ezmsg/core/shm.py b/src/ezmsg/core/shm.py index 3bee0c1d..89f9c7e0 100644 --- a/src/ezmsg/core/shm.py +++ b/src/ezmsg/core/shm.py @@ -126,7 +126,7 @@ def close(self) -> None: async def wait_closed(self) -> None: with suppress(asyncio.CancelledError): await self._graph_task - + @property def name(self) -> str: return self._shm.name diff --git a/src/ezmsg/core/subclient.py b/src/ezmsg/core/subclient.py index acb791af..ebed727e 100644 --- a/src/ezmsg/core/subclient.py +++ b/src/ezmsg/core/subclient.py @@ -30,9 +30,14 @@ class Subscriber: _cur_pubs: typing.Set[UUID] _incoming: NotificationQueue - # This is an optimization to retain a local handle to channels - # so that dict lookup and wrapper contextmanager aren't in - # the hotpath - Griff + # FIXME: This event allows Subscriber.create to block until + # incoming initial connections (UPDATE) has completed. The + # logic is confusing and difficult to follow, but this event + # serves an important purpose for the time being. + _initialized: asyncio.Event + + # NOTE: This is an optimization to retain a local handle to channels + # so that dict lookup and wrapper contextmanager aren't in hotpath _channels: typing.Dict[UUID, _Channel] @classmethod @@ -42,28 +47,30 @@ async def create( graph_address: AddressType | None, **kwargs ) -> "Subscriber": - logger.info(f'attempting to create sub {topic=}') reader, writer = await GraphService(graph_address).open_connection() writer.write(Command.SUBSCRIBE.value) writer.write(encode_str(topic)) + sub_id_str = await read_str(reader) + sub_id = UUID(sub_id_str) + + sub = cls(sub_id, topic, graph_address, **kwargs) - result = await reader.read(1) - if result != Command.COMPLETE.value: - logger.warning(f'Could not create subscriber {topic=}') - - id_str = await read_str(reader) - - sub = cls(UUID(id_str), topic, graph_address, **kwargs) sub._graph_task = asyncio.create_task( sub._graph_connection(reader, writer), name = f'sub-{sub.id}: _graph_connection' ) - logger.info(f'created sub {topic=} {sub.id=}') + # FIXME: We need to wait for _graph_task to service an UPDATE + # then receive a COMPLETE before we return a fully connected + # subscriber ready for recv. + await sub._initialized.wait() + + logger.debug(f'created sub {sub.id=} {topic=}') return sub - def __init__(self, + def __init__( + self, id: UUID, topic: str, graph_address: AddressType | None, @@ -77,6 +84,7 @@ def __init__(self, self._cur_pubs = set() self._incoming = asyncio.Queue() self._channels = dict() + self._initialized = asyncio.Event() def close(self) -> None: self._graph_task.cancel() @@ -95,10 +103,33 @@ async def _graph_connection( if not cmd: break + if cmd == Command.COMPLETE.value: + # FIXME: The only time GraphServer will send us a COMPLETE + # is when it is done with Subscriber.create. Unfortunately + # part of creating the subscriber involves receiving an + # UPDATE with all of the initial connections so that + # messaging resolves as expected immediately after creation + # is completed. The only way we can service this UPDATE + # is by passing control of the GraphServer's StreamReader + # to this task, which can handle the UPDATE -- mid-creation + # then we receive the COMPLETE in here, set the _initialized + # event, which we wait on in Subscriber.create, releasing the + # block and returning a fully connected/ready subscriber. + # While this currently works, its non-obvious what this + # accomplishes and why its implemented this way. I just + # wasted 2 hours removing this seemingly un-necessary event + # to introduce a bug that is only resolved by reintroducing + # the event. Some thought should be put into replacing the + # bespoke communication protocol with something a bit more + # standard (JSON RPC?) with a more common/logical pattern + # for creating/handling comms. We will probably want to + # keep the bespoke comms for the publisher->channel link + # as that is in the hot path. + self._initialized.set() + elif cmd == Command.UPDATE.value: update = await read_str(reader) pub_ids = set([UUID(id) for id in update.split(',')]) if update else set() - logger.info(f'{pub_ids=}') for pub_id in set(pub_ids - self._cur_pubs): channel = await CHANNELS.register(pub_id, self.id, self._incoming, self._graph_address) diff --git a/tests/test_graph.py b/tests/test_graph.py index c3bb3839..e8927e70 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -30,7 +30,7 @@ async def test_pub_first(): await asyncio.sleep(0.1) await context.connect("a", "b") await asyncio.sleep(0.1) - await context.publisher("b") + await context.subscriber("b") await asyncio.sleep(0.1) @@ -83,6 +83,19 @@ async def test_graph(): await context.connect("c", "a") +@pytest.mark.asyncio +async def test_comms_simple(): + async with GraphContext() as context: + b_sub = await context.subscriber("b") + await context.connect("a", "b") + a_pub = await context.publisher("a") + await a_pub.broadcast("HELLO") + print('DONE BROADCASTING') + msg = await b_sub.recv() + assert msg == "HELLO" + + + @pytest.mark.asyncio async def test_comms(): async with GraphContext() as context: From 5ac957c18eb56ac5fd2b0352b8a47d0903e2f6cd Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Tue, 26 Aug 2025 09:56:06 -0400 Subject: [PATCH 30/88] fixed intermittent failure of test_filter_key --- src/ezmsg/core/messagechannel.py | 43 +++++++++++++++----------------- tests/ez_test_utils.py | 15 ++++++----- 2 files changed, 27 insertions(+), 31 deletions(-) diff --git a/src/ezmsg/core/messagechannel.py b/src/ezmsg/core/messagechannel.py index 557bb05f..722ac9d0 100644 --- a/src/ezmsg/core/messagechannel.py +++ b/src/ezmsg/core/messagechannel.py @@ -300,7 +300,6 @@ def _ensure_address(address: AddressType | None) -> Address: class _ChannelManager: - _lock: asyncio.Lock _registry: typing.Dict[Address, typing.Dict[UUID, _Channel]] def __init__(self): @@ -315,17 +314,16 @@ async def register( queue: NotificationQueue | None = None, graph_address: AddressType | None = None ) -> _Channel: - async with self._lock: - logger.debug(f'ch {pub_id=} {client_id=} reg DONE') - graph_address = _ensure_address(graph_address) - channels = self._registry.get(graph_address, dict()) - channel = channels.get(pub_id, None) - if channel is None: - channel = await _Channel.create(pub_id, graph_address) - channels[pub_id] = channel - self._registry[graph_address] = channels - channel.register_client(client_id, queue) - return channel + logger.debug(f'ch {pub_id=} {client_id=} reg DONE') + graph_address = _ensure_address(graph_address) + channels = self._registry.get(graph_address, dict()) + channel = channels.get(pub_id, None) + if channel is None: + channel = await _Channel.create(pub_id, graph_address) + channels[pub_id] = channel + self._registry[graph_address] = channels + channel.register_client(client_id, queue) + return channel async def unregister( self, @@ -333,17 +331,16 @@ async def unregister( client_id: UUID, graph_address: AddressType | None = None ) -> None: - async with self._lock: - logger.debug(f'ch {pub_id=} {client_id=} unreg') - graph_address = _ensure_address(graph_address) - registry = self._registry[graph_address] - channel = registry[pub_id] - channel.unregister_client(client_id) - if len(channel.clients) == 0: - channel.close() - await channel.wait_closed() - del registry[pub_id] - logger.debug(f'closed channel {pub_id}: no clients') + logger.debug(f'ch {pub_id=} {client_id=} unreg') + graph_address = _ensure_address(graph_address) + registry = self._registry[graph_address] + channel = registry[pub_id] + channel.unregister_client(client_id) + if len(channel.clients) == 0: + channel.close() + await channel.wait_closed() + del registry[pub_id] + logger.debug(f'closed channel {pub_id}: no clients') CHANNELS = _ChannelManager() \ No newline at end of file diff --git a/tests/ez_test_utils.py b/tests/ez_test_utils.py index abf6af20..2c50e3f5 100644 --- a/tests/ez_test_utils.py +++ b/tests/ez_test_utils.py @@ -19,14 +19,13 @@ def get_test_fn(test_name: typing.Optional[str] = None, extension: str = "txt") else: test_name = __name__ - file_path = Path(tempfile.gettempdir()) - file_path = file_path / Path(f"{test_name}.{extension}") - - # Create the file - with open(file_path, "w"): - pass - - return file_path + # Create a unique temporary file name to avoid collisions when running the + # full test suite in parallel or when other tests use the same test name. + # Use NamedTemporaryFile with delete=False so callers can open/remove it. + prefix = f"{test_name}-" if test_name else "test-" + tmp = tempfile.NamedTemporaryFile(prefix=prefix, suffix=f".{extension}", delete=False) + tmp.close() + return Path(tmp.name) # MESSAGE DEFINITIONS From ef073b8205a76a803d203649888d773a88cecc6f Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Tue, 16 Sep 2025 14:39:49 -0400 Subject: [PATCH 31/88] Fixed issue with message channels attempting to connect to canonical port while running locally --- src/ezmsg/core/backend.py | 5 +++++ src/ezmsg/core/messagechannel.py | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/ezmsg/core/backend.py b/src/ezmsg/core/backend.py index 62e3d974..9647da10 100644 --- a/src/ezmsg/core/backend.py +++ b/src/ezmsg/core/backend.py @@ -228,6 +228,11 @@ async def create_graph_context() -> GraphContext: create_graph_context(), loop ).result() + if graph_context._graph_server is None: + logger.info(f'Connected to GraphServer @ {graph_context.graph_address}') + else: + logger.info(f'Spawned LOCAL GraphServer @ {graph_context.graph_address}') + execution_context.create_processes( graph_address=graph_context.graph_address, backend_process=backend_process diff --git a/src/ezmsg/core/messagechannel.py b/src/ezmsg/core/messagechannel.py index 722ac9d0..e1fb2744 100644 --- a/src/ezmsg/core/messagechannel.py +++ b/src/ezmsg/core/messagechannel.py @@ -58,7 +58,7 @@ def __init__( pub_id: UUID, num_buffers: int, shm: SHMContext | None, - graph_address: AddressType | None = None + graph_address: AddressType | None ) -> None: self.id = id self.pub_id = pub_id @@ -114,7 +114,7 @@ async def create( num_buffers = await read_int(reader) - chan = cls(UUID(id_str), pub_id, num_buffers, shm) + chan = cls(UUID(id_str), pub_id, num_buffers, shm, graph_address) chan._graph_task = asyncio.create_task( chan._graph_connection(graph_reader, graph_writer), From afd64d863c8915480e743b666d5405e2e125e0f5 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Tue, 16 Sep 2025 16:42:43 -0400 Subject: [PATCH 32/88] fixed backpressure from pub --- src/ezmsg/core/messagechannel.py | 51 ++++++++++++++++++++++---------- src/ezmsg/core/pubclient.py | 4 +-- 2 files changed, 37 insertions(+), 18 deletions(-) diff --git a/src/ezmsg/core/messagechannel.py b/src/ezmsg/core/messagechannel.py index e1fb2744..b9ad2b9a 100644 --- a/src/ezmsg/core/messagechannel.py +++ b/src/ezmsg/core/messagechannel.py @@ -209,16 +209,25 @@ async def _publisher_connection(self, reader: asyncio.StreamReader) -> None: def _notify_clients(self, msg_id: int) -> None: for client_id, queue in self.clients.items(): - if queue is None: continue + if queue is None: continue # queue is none if this is the pub self.backpressure.lease(client_id, msg_id % self.num_buffers) queue.put_nowait((self.pub_id, msg_id)) - def put_local(self, msg_id: int, msg: typing.Any) -> None: - """put an object into cache (should only be used by Publishers)""" + def put_local(self, msg_id: int, msg: typing.Any) -> bool: + """ + put an object into cache (should only be used by Publishers) + returns true if any clients were notified + """ + self._notify_clients(msg_id) + + # if buffer is still available after notify, no subs were notified here buf_idx = msg_id % self.num_buffers + if self.backpressure.available(buf_idx): + return False + self.cache_id[buf_idx] = msg_id self.cache[buf_idx] = msg - self._notify_clients(msg_id) + return True @contextmanager def get(self, msg_id: int, client_id: UUID) -> typing.Generator[typing.Any, None, None]: @@ -307,21 +316,31 @@ def __init__(self): self._registry = {default_address: dict()} self._lock = asyncio.Lock() - async def register( - self, - pub_id: UUID, - client_id: UUID, - queue: NotificationQueue | None = None, - graph_address: AddressType | None = None + async def get( + self, + pub_id: UUID, + graph_address: AddressType | None = None, + create: bool = False ) -> _Channel: - logger.debug(f'ch {pub_id=} {client_id=} reg DONE') graph_address = _ensure_address(graph_address) channels = self._registry.get(graph_address, dict()) channel = channels.get(pub_id, None) - if channel is None: + if create and channel is None: channel = await _Channel.create(pub_id, graph_address) channels[pub_id] = channel self._registry[graph_address] = channels + if channel is None: + raise KeyError("channel does not exist") + return channel + + async def register( + self, + pub_id: UUID, + client_id: UUID, + queue: NotificationQueue | None = None, + graph_address: AddressType | None = None + ) -> _Channel: + channel = await self.get(pub_id, graph_address, create = True) channel.register_client(client_id, queue) return channel @@ -331,14 +350,14 @@ async def unregister( client_id: UUID, graph_address: AddressType | None = None ) -> None: - logger.debug(f'ch {pub_id=} {client_id=} unreg') - graph_address = _ensure_address(graph_address) - registry = self._registry[graph_address] - channel = registry[pub_id] + channel = await self.get(pub_id, graph_address) channel.unregister_client(client_id) + if len(channel.clients) == 0: channel.close() await channel.wait_closed() + graph_address = _ensure_address(graph_address) + registry = self._registry[graph_address] del registry[pub_id] logger.debug(f'closed channel {pub_id}: no clients') diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index 5f2852e3..c5bfee1f 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -290,8 +290,8 @@ async def broadcast(self, obj: Any) -> None: # Get local channel and put variable there for local tx if not self._force_tcp: - self._local_channel.put_local(self._msg_id, obj) - self._backpressure.lease(self._local_channel.id, buf_idx) + if self._local_channel.put_local(self._msg_id, obj): + self._backpressure.lease(self._local_channel.id, buf_idx) if self._force_tcp or any(ch.pid != self.pid or not ch.shm_ok for ch in self._channels.values()): with MessageMarshal.serialize(self._msg_id, obj) as (total_size, header, buffers): From 7347424e8d1d2980e3503431a2472d29d670441d Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Wed, 17 Sep 2025 09:57:27 -0400 Subject: [PATCH 33/88] attach example was broken --- examples/ezmsg_attach.py | 2 +- examples/ezmsg_toy.py | 8 ++++---- src/ezmsg/core/backend.py | 5 ++++- src/ezmsg/core/messagechannel.py | 2 +- src/ezmsg/core/pubclient.py | 2 +- 5 files changed, 11 insertions(+), 8 deletions(-) diff --git a/examples/ezmsg_attach.py b/examples/ezmsg_attach.py index 09fe4786..b75241be 100644 --- a/examples/ezmsg_attach.py +++ b/examples/ezmsg_attach.py @@ -5,4 +5,4 @@ if __name__ == "__main__": print("This example attaches to the system created/run by ezmsg_toy.py.") log = DebugLog() - ez.run(log, connections=(("TestSystem/PING/OUTPUT", log.INPUT),)) + ez.run(LOG=log, connections=(("GLOBAL_PING_TOPIC", log.INPUT),)) diff --git a/examples/ezmsg_toy.py b/examples/ezmsg_toy.py index 3c4238d0..a0dab836 100644 --- a/examples/ezmsg_toy.py +++ b/examples/ezmsg_toy.py @@ -192,8 +192,8 @@ def process_components(self): ez.run( SYSTEM=system, - # connections = [ - # ( system.PING.OUTPUT, 'PING_OUTPUT' ), - # ( 'FOO_SUB', system.FOOSUB.INPUT ) - # ] + connections = [ + # Make PING.OUTPUT available on a topic ezmsg_attach.py + ( system.PING.OUTPUT, 'GLOBAL_PING_TOPIC' ), + ] ) diff --git a/src/ezmsg/core/backend.py b/src/ezmsg/core/backend.py index 9647da10..7ca5b98e 100644 --- a/src/ezmsg/core/backend.py +++ b/src/ezmsg/core/backend.py @@ -229,7 +229,10 @@ async def create_graph_context() -> GraphContext: ).result() if graph_context._graph_server is None: - logger.info(f'Connected to GraphServer @ {graph_context.graph_address}') + address = graph_context.graph_address + if address is None: + address = GraphService.default_address() + logger.info(f'Connected to GraphServer @ {address}') else: logger.info(f'Spawned LOCAL GraphServer @ {graph_context.graph_address}') diff --git a/src/ezmsg/core/messagechannel.py b/src/ezmsg/core/messagechannel.py index b9ad2b9a..1935a7f3 100644 --- a/src/ezmsg/core/messagechannel.py +++ b/src/ezmsg/core/messagechannel.py @@ -127,7 +127,7 @@ async def create( name = f'chan-{chan.id}: _publisher_connection' ) - logger.debug(f'created channel {chan.id=} {pub_id=}') + logger.debug(f'created channel {chan.id=} {pub_id=} {pub_address=}') return chan diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index c5bfee1f..00a77d0c 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -121,7 +121,7 @@ async def create( pub.id, pub.id, None, pub._graph_address ) - logger.debug(f'created pub {pub.id=} {topic=}') + logger.debug(f'created pub {pub.id=} {topic=} {channel_server_address=}') return pub From 3cb550b10f4b61de8a965f5e50880c09abb68745 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Wed, 17 Sep 2025 10:38:47 -0400 Subject: [PATCH 34/88] working around niche race condition --- src/ezmsg/core/pubclient.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index 00a77d0c..5087fa97 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -86,18 +86,16 @@ async def create( reader, writer = await graph_service.open_connection() shm = await graph_service.create_shm(num_buffers, buf_size) - start_port = int( - os.getenv(PUBLISHER_START_PORT_ENV, PUBLISHER_START_PORT_DEFAULT) - ) - - sock = create_socket(host, port, start_port=start_port) - writer.write(Command.PUBLISH.value) writer.write(encode_str(topic)) pub_id = UUID(await read_str(reader)) pub = cls(pub_id, topic, shm, graph_address, num_buffers, **kwargs) + start_port = int( + os.getenv(PUBLISHER_START_PORT_ENV, PUBLISHER_START_PORT_DEFAULT) + ) + sock = create_socket(host, port, start_port=start_port) server = await asyncio.start_server(pub._channel_connect, sock=sock) pub._connection_task = asyncio.create_task( pub._serve_channels(server), From 6b9a6ee0498a1c3e1eae3f8e30b4a5246dd27e46 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Wed, 17 Sep 2025 11:38:51 -0400 Subject: [PATCH 35/88] addressed socket bind race condition on linux --- src/ezmsg/core/netprotocol.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ezmsg/core/netprotocol.py b/src/ezmsg/core/netprotocol.py index cef6729e..da56734c 100644 --- a/src/ezmsg/core/netprotocol.py +++ b/src/ezmsg/core/netprotocol.py @@ -189,12 +189,12 @@ def create_socket( ignore_ports: typing.List[int] = RESERVED_PORTS, ) -> socket.socket: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) if host is None: host = DEFAULT_HOST if port is not None: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.bind((host, port)) else: @@ -208,10 +208,10 @@ def create_socket( break except OSError: pass + port += 1 if not bound: raise IOError("Failed to bind socket; no free ports") - sock.setblocking(False) return sock From 6643d7c375db33ee06ad9a9bd0a214b9e20ed254 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Wed, 17 Sep 2025 11:43:05 -0400 Subject: [PATCH 36/88] comments to document REUSEADDR behavior --- src/ezmsg/core/netprotocol.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/ezmsg/core/netprotocol.py b/src/ezmsg/core/netprotocol.py index da56734c..9d78e22c 100644 --- a/src/ezmsg/core/netprotocol.py +++ b/src/ezmsg/core/netprotocol.py @@ -188,12 +188,13 @@ def create_socket( max_port: int = 65535, ignore_ports: typing.List[int] = RESERVED_PORTS, ) -> socket.socket: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) if host is None: host = DEFAULT_HOST if port is not None: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # set REUSEADDR if a port is explicitly requested; leads to quick server restart on same port sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.bind((host, port)) @@ -203,6 +204,8 @@ def create_socket( while port <= max_port: if port not in ignore_ports: try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # setting REUSEADDR during portscan can lead to race conditions during bind on Linux sock.bind((host, port)) bound = True break From 37506b7656d5696eb927e95cc753c0cfe1641844 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Wed, 17 Sep 2025 11:46:26 -0400 Subject: [PATCH 37/88] explicit close of socket resource --- src/ezmsg/core/netprotocol.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/ezmsg/core/netprotocol.py b/src/ezmsg/core/netprotocol.py index 9d78e22c..1c49d0b6 100644 --- a/src/ezmsg/core/netprotocol.py +++ b/src/ezmsg/core/netprotocol.py @@ -203,13 +203,14 @@ def create_socket( port = start_port while port <= max_port: if port not in ignore_ports: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # setting REUSEADDR during portscan can lead to race conditions during bind on Linux sock.bind((host, port)) bound = True break except OSError: + sock.close() pass port += 1 From 716140348af98c124865d205177e0a5bb50fd7c0 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Wed, 17 Sep 2025 12:10:25 -0400 Subject: [PATCH 38/88] disable nagles algorithm --- src/ezmsg/core/netprotocol.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/ezmsg/core/netprotocol.py b/src/ezmsg/core/netprotocol.py index 1c49d0b6..a304331f 100644 --- a/src/ezmsg/core/netprotocol.py +++ b/src/ezmsg/core/netprotocol.py @@ -196,6 +196,7 @@ def create_socket( sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # set REUSEADDR if a port is explicitly requested; leads to quick server restart on same port sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) sock.bind((host, port)) else: @@ -206,6 +207,7 @@ def create_socket( sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) try: # setting REUSEADDR during portscan can lead to race conditions during bind on Linux + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) sock.bind((host, port)) bound = True break From cb15e57cefa992f3a6b89b40417bbcbe1040cd4b Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Wed, 17 Sep 2025 13:51:11 -0400 Subject: [PATCH 39/88] fixed race condition on channel registration --- src/ezmsg/core/backend.py | 2 +- src/ezmsg/core/messagechannel.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/ezmsg/core/backend.py b/src/ezmsg/core/backend.py index 7ca5b98e..19a377d6 100644 --- a/src/ezmsg/core/backend.py +++ b/src/ezmsg/core/backend.py @@ -234,7 +234,7 @@ async def create_graph_context() -> GraphContext: address = GraphService.default_address() logger.info(f'Connected to GraphServer @ {address}') else: - logger.info(f'Spawned LOCAL GraphServer @ {graph_context.graph_address}') + logger.info(f'Spawned GraphServer @ {graph_context.graph_address}') execution_context.create_processes( graph_address=graph_context.graph_address, diff --git a/src/ezmsg/core/messagechannel.py b/src/ezmsg/core/messagechannel.py index 1935a7f3..c55b4b1d 100644 --- a/src/ezmsg/core/messagechannel.py +++ b/src/ezmsg/core/messagechannel.py @@ -281,12 +281,12 @@ def register_client(self, client_id: UUID, queue: NotificationQueue | None = Non def unregister_client(self, client_id: UUID) -> None: queue = self.clients[client_id] + # queue is only 'None' if this client is a local publisher if queue is not None: for _ in range(queue.qsize()): pub_id, msg_id = queue.get_nowait() - if pub_id == self.pub_id: - continue - queue.put_nowait((pub_id, msg_id)) + if pub_id != self.pub_id: + queue.put_nowait((pub_id, msg_id)) self.backpressure.free(client_id) @@ -323,14 +323,14 @@ async def get( create: bool = False ) -> _Channel: graph_address = _ensure_address(graph_address) - channels = self._registry.get(graph_address, dict()) - channel = channels.get(pub_id, None) + channel = self._registry.get(graph_address, dict()).get(pub_id, None) if create and channel is None: channel = await _Channel.create(pub_id, graph_address) + channels = self._registry.get(graph_address, dict()) channels[pub_id] = channel self._registry[graph_address] = channels if channel is None: - raise KeyError("channel does not exist") + raise KeyError(f"channel {pub_id=} {graph_address=} does not exist") return channel async def register( From 821c7e385a8859df5d0ba1f0848876baf9bd254f Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Wed, 17 Sep 2025 16:25:28 -0400 Subject: [PATCH 40/88] reverting buffer stride calculation change --- src/ezmsg/core/shm.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/ezmsg/core/shm.py b/src/ezmsg/core/shm.py index 89f9c7e0..1e02bb5b 100644 --- a/src/ezmsg/core/shm.py +++ b/src/ezmsg/core/shm.py @@ -20,7 +20,6 @@ """ ezmsg shared memory format: - TODO: [UUID] [ UINT64 -- n_buffers ] [ UINT64 -- buf_size ] [ buf_size - 16 -- buf0 data_block ] @@ -74,7 +73,7 @@ def __init__(self, name: str) -> None: with self._shm.buf[8:16] as buf_size_mem: self.buf_size = bytes_to_uint(buf_size_mem) - buf_starts = [buf_idx * (self.buf_size + 16) for buf_idx in range(self.num_buffers)] + buf_starts = [buf_idx * self.buf_size for buf_idx in range(self.num_buffers)] buf_stops = [buf_start + self.buf_size for buf_start in buf_starts] buf_data_block_starts = [buf_start + 16 for buf_start in buf_starts] From 6b1277815addf455c569f827fe0922a60c21e5d2 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Fri, 19 Sep 2025 16:30:00 -0400 Subject: [PATCH 41/88] much more comprehensive perf test with analysis: ezmsg-perf --- pyproject.toml | 4 +- src/ezmsg/util/perf/__init__.py | 0 src/ezmsg/util/perf/analysis.py | 113 ++++++++++ src/ezmsg/util/perf/command.py | 55 +++++ src/ezmsg/util/perf/eval.py | 91 ++++++++ src/ezmsg/util/perf/util.py | 356 ++++++++++++++++++++++++++++++++ src/ezmsg/util/perf_test.py | 220 -------------------- 7 files changed, 618 insertions(+), 221 deletions(-) create mode 100644 src/ezmsg/util/perf/__init__.py create mode 100644 src/ezmsg/util/perf/analysis.py create mode 100644 src/ezmsg/util/perf/command.py create mode 100644 src/ezmsg/util/perf/eval.py create mode 100644 src/ezmsg/util/perf/util.py delete mode 100644 src/ezmsg/util/perf_test.py diff --git a/pyproject.toml b/pyproject.toml index df1f469a..961271b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,8 @@ test = [ "pytest>=8.4.1", "pytest-asyncio>=1.1.0", "pytest-cov>=6.2.1", - "xarray>=2023.1.0;python_version<'3.13'" + "xarray>=2023.1.0;python_version<'3.13'", + "psutil>=7.1.0", ] docs = [ "ezmsg-sigproc>=2.2.0", @@ -44,6 +45,7 @@ docs = [ [project.scripts] ezmsg = "ezmsg.core.command:cmdline" +ezmsg-perf = "ezmsg.util.perf.command:command" [project.optional-dependencies] axisarray = [ diff --git a/src/ezmsg/util/perf/__init__.py b/src/ezmsg/util/perf/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/ezmsg/util/perf/analysis.py b/src/ezmsg/util/perf/analysis.py new file mode 100644 index 00000000..48dd800e --- /dev/null +++ b/src/ezmsg/util/perf/analysis.py @@ -0,0 +1,113 @@ +import json +import typing +import dataclasses + +from pathlib import Path + +from ..messagecodec import MessageDecoder +from .util import ( + TestEnvironmentInfo, + TestParameters, + Metrics, +) + +import ezmsg.core as ez + +try: + import xarray as xr +except ImportError: + ez.logger.error('ezmsg perf analysis requires xarray') + raise + +try: + import numpy as np +except ImportError: + ez.logger.error('ezmsg perf analysis requires numpy') + raise + + +def load_perf(perf: Path) -> xr.Dataset: + + params: typing.List[TestParameters] = [] + results: typing.List[Metrics] = [] + + with open(perf, 'r') as perf_f: + info: TestEnvironmentInfo = json.loads(next(perf_f), cls = MessageDecoder) + + for line in perf_f: + obj = json.loads(line, cls = MessageDecoder) + params.append(obj['params']) + results.append(obj['results']) + + n_clients_axis = list(sorted(set([p.n_clients for p in params]))) + msg_size_axis = list(sorted(set([p.msg_size for p in params]))) + comms_axis = list(sorted(set([p.comms for p in params]))) + config_axis = list(sorted(set([p.config for p in params]))) + + dims = ['n_clients', 'msg_size', 'comms', 'config'] + coords = { + 'n_clients': n_clients_axis, + 'msg_size': msg_size_axis, + 'comms': comms_axis, + 'config': config_axis + } + + data_vars = {} + for field in dataclasses.fields(Metrics): + m = np.zeros(( + len(n_clients_axis), + len(msg_size_axis), + len(comms_axis), + len(config_axis) + )) * np.nan + for p, r in zip(params, results): + m[ + n_clients_axis.index(p.n_clients), + msg_size_axis.index(p.msg_size), + comms_axis.index(p.comms), + config_axis.index(p.config) + ] = getattr(r, field.name) + data_vars[field.name] = xr.DataArray(m, dims = dims, coords = coords) + + dataset = xr.Dataset(data_vars, attrs = dict(info = info)) + return dataset + +@dataclasses.dataclass +class SummaryArgs: + perf: Path + baseline: Path | None + +def summary(args: SummaryArgs): + + perf_dataset = load_perf(args.perf) + + for config, config_ds in perf_dataset.groupby('config'): + for comms, comms_ds in config_ds.groupby('comms'): + print(f'{config}: {comms}') + print(comms_ds.sample_rate.data) + + if args.baseline is not None: + baseline_dataset = load_perf(args.baseline) + + +def command() -> None: + import argparse + + parser = argparse.ArgumentParser() + + parser.add_argument( + "perf", + type=lambda x: Path(x), + help="perf test", + ) + + parser.add_argument( + "--baseline", "-b", + type=lambda x: Path(x), + default = None, + help="baseline perf test for comparison" + ) + + args = parser.parse_args(namespace=SummaryArgs) + + summary(args) diff --git a/src/ezmsg/util/perf/command.py b/src/ezmsg/util/perf/command.py new file mode 100644 index 00000000..5269a1d0 --- /dev/null +++ b/src/ezmsg/util/perf/command.py @@ -0,0 +1,55 @@ +from pathlib import Path + +from .analysis import summary, SummaryArgs +from .eval import perf_eval, PerfEvalArgs + +def command() -> None: + import argparse + + parser = argparse.ArgumentParser(description = 'ezmsg perf test utility') + subparsers = parser.add_subparsers(dest="command", required=True) + + p_run = subparsers.add_parser("run", help="run performance test") + p_run.add_argument( + "--duration", + type=float, + default=2.0, + help="individual test duration in seconds (default = 2.0)", + ) + p_run.add_argument( + "--num-buffers", + type=int, + default=32, + help="shared memory buffers (default = 32)", + ) + + p_run.set_defaults(_handler=lambda ns: perf_eval( + PerfEvalArgs( + duration = ns.duration, + num_buffers = ns.num_buffers + ) + )) + + p_summary = subparsers.add_parser("summary", help = "summarise performance results") + p_summary.add_argument( + "perf", + type=Path, + help="perf test", + ) + p_summary.add_argument( + "--baseline", + "-b", + type=Path, + default=None, + help="baseline perf test for comparison", + ) + + p_summary.set_defaults(_handler=lambda ns: summary( + SummaryArgs( + perf = ns.perf, + baseline = ns.baseline + ) + )) + + ns = parser.parse_args() + ns._handler(ns) diff --git a/src/ezmsg/util/perf/eval.py b/src/ezmsg/util/perf/eval.py new file mode 100644 index 00000000..894c45db --- /dev/null +++ b/src/ezmsg/util/perf/eval.py @@ -0,0 +1,91 @@ +import json +import datetime +import itertools + +from dataclasses import dataclass + +from ..messagecodec import MessageEncoder +from .util import ( + TestEnvironmentInfo, + TestParameters, + perform_test, + Communication, + CONFIGS, +) + +import ezmsg.core as ez + +def get_datestamp() -> str: + return datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + +@dataclass +class PerfEvalArgs: + duration: float + num_buffers: int + +def perf_eval(args: PerfEvalArgs) -> None: + + msg_sizes = [2 ** exp for exp in range(4, 25, 4)] + n_clients = [2 ** exp for exp in range(0, 6)] + comms = [c for c in Communication] + + test_list = list(itertools.product(msg_sizes, n_clients, CONFIGS, comms)) + + with open(f'perf_{get_datestamp()}.txt', 'w') as out_f: + + out_f.write(json.dumps(TestEnvironmentInfo(), cls = MessageEncoder) + "\n") + + for test_idx, (msg_size, n_clients, config, comms) in enumerate(test_list): + + ez.logger.info(f"RUNNING TEST {test_idx + 1} / {len(test_list)} ({(test_idx / len(test_list)) * 100.0:0.2f} %)") + + params = TestParameters( + msg_size = msg_size, + n_clients = n_clients, + config = config.__name__, + comms = comms.value, + duration = args.duration, + num_buffers = args.num_buffers + ) + + results = perform_test( + n_clients = n_clients, + duration = args.duration, + msg_size = msg_size, + buffers = args.num_buffers, + comms = comms, + config = config, + ) + + output = dict( + params = params, + results = results + ) + + out_f.write(json.dumps(output, cls = MessageEncoder) + "\n") + + +def command() -> None: + import argparse + + parser = argparse.ArgumentParser() + + parser.add_argument( + "--duration", + type=float, + default=2.0, + help="How long to run each load test (seconds) (default = 2.0)", + ) + + parser.add_argument( + "--num-buffers", + type=int, + default=32, + help="shared memory buffers (default = 32)" + ) + + args = parser.parse_args(namespace=PerfEvalArgs) + + perf_eval(args) + + diff --git a/src/ezmsg/util/perf/util.py b/src/ezmsg/util/perf/util.py new file mode 100644 index 00000000..8883591b --- /dev/null +++ b/src/ezmsg/util/perf/util.py @@ -0,0 +1,356 @@ +import asyncio +import dataclasses +import datetime +import os +import platform +import time +import typing +import enum +import sys +import subprocess + +import ezmsg.core as ez + +try: + import numpy as np +except ImportError: + ez.logger.error("ezmsg perf requires numpy") + raise + +try: + import psutil +except ImportError: + ez.logger.error("ezmsg perf requires psutil") + raise + + +def collect( + components: typing.Optional[typing.Mapping[str, ez.Component]] = None, + network: ez.NetworkDefinition = (), + process_components: typing.Collection[ez.Component] | None = None, + **components_kwargs: ez.Component, +) -> ez.Collection: + """ collect a grouping of pre-configured components into a new "Collection" """ + from ezmsg.core.util import either_dict_or_kwargs + + components = either_dict_or_kwargs(components, components_kwargs, "collect") + if components is None: + raise ValueError("Must supply at least one component to run") + + out = ez.Collection() + for name, comp in components.items(): + comp._set_name(name) + out._components = components # FIXME: Component._components should be typehinted as a Mapping + out.network = lambda: network + out.process_components = lambda: (out,) if process_components is None else process_components + return out + + +@dataclasses.dataclass +class Metrics: + num_msgs: int + sample_rate: float + latency_mean: float + latency_total: float + data_rate: float + + +class LoadTestSettings(ez.Settings): + duration: float + dynamic_size: int + buffers: int + force_tcp: bool + + +@dataclasses.dataclass +class LoadTestSample: + _timestamp: float + counter: int + dynamic_data: np.ndarray + key: str + + +class LoadTestSender(ez.Unit): + OUTPUT = ez.OutputStream(LoadTestSample) + SETTINGS = LoadTestSettings + + async def initialize(self) -> None: + self.running = True + self.counter = 0 + + self.OUTPUT.num_buffers = self.SETTINGS.buffers + self.OUTPUT.force_tcp = self.SETTINGS.force_tcp + + @ez.publisher(OUTPUT) + async def publish(self) -> typing.AsyncGenerator: + ez.logger.info(f"Load test publisher started. (PID: {os.getpid()})") + start_time = time.perf_counter() + while self.running: + current_time = time.perf_counter() + if current_time - start_time >= self.SETTINGS.duration: + break + + yield ( + self.OUTPUT, + LoadTestSample( + _timestamp=time.perf_counter(), + counter=self.counter, + dynamic_data=np.zeros( + int(self.SETTINGS.dynamic_size // 8), dtype=np.float32 + ), + key = self.name, + ), + ) + self.counter += 1 + ez.logger.info("Exiting publish") + raise ez.Complete + +class LoadTestSource(LoadTestSender): + async def shutdown(self) -> None: + self.running = False + ez.logger.info(f"Samples sent: {self.counter}") + + +class LoadTestRelay(ez.Unit): + INPUT = ez.InputStream(LoadTestSample) + OUTPUT = ez.OutputStream(LoadTestSample) + + @ez.subscriber(INPUT, zero_copy = True) + @ez.publisher(OUTPUT) + async def on_msg(self, msg: LoadTestSample) -> typing.AsyncGenerator: + yield self.OUTPUT, msg + + +class LoadTestReceiverState(ez.State): + # Tuples of sent timestamp, received timestamp, counter, dynamic size + received_data: typing.List[typing.Tuple[float, float, int]] = dataclasses.field( + default_factory=list + ) + counters: typing.Dict[str, int] = dataclasses.field(default_factory=dict) + + +class LoadTestReceiver(ez.Unit): + INPUT = ez.InputStream(LoadTestSample) + SETTINGS = LoadTestSettings + STATE = LoadTestReceiverState + + @ez.subscriber(INPUT, zero_copy=True) + async def receive(self, sample: LoadTestSample) -> None: + counter = self.STATE.counters.get(sample.key, -1) + if sample.counter != counter + 1: + ez.logger.warning( + f"{sample.counter - counter - 1} samples skipped!" + ) + self.STATE.received_data.append( + (sample._timestamp, time.perf_counter(), sample.counter) + ) + self.STATE.counters[sample.key] = sample.counter + + +class LoadTestSink(LoadTestReceiver): + + @ez.task + async def terminate(self) -> None: + ez.logger.info(f"Load test subscriber started. (PID: {os.getpid()})") + + # Wait for the duration of the load test + await asyncio.sleep(self.SETTINGS.duration) + raise ez.NormalTermination + + +### TEST CONFIGURATIONS + +@dataclasses.dataclass +class ConfigSettings: + n_clients: int + settings: LoadTestSettings + source: LoadTestSource + sink: LoadTestSink + +Configuration = typing.Tuple[typing.Iterable[ez.Component], ez.NetworkDefinition] +Configurator = typing.Callable[[ConfigSettings], Configuration] + +def fanout(config: ConfigSettings) -> Configuration: + """ one pub to many subs """ + connections: ez.NetworkDefinition = [(config.source.OUTPUT, config.sink.INPUT)] + subs = [LoadTestReceiver(config.settings) for _ in range(config.n_clients)] + for sub in subs: + connections.append((config.source.OUTPUT, sub.INPUT)) + + return subs, connections + +def fanin(config: ConfigSettings) -> Configuration: + """ many pubs to one sub """ + connections: ez.NetworkDefinition = [(config.source.OUTPUT, config.sink.INPUT)] + pubs = [LoadTestSender(config.settings) for _ in range(config.n_clients)] + for pub in pubs: + connections.append((pub.OUTPUT, config.sink.INPUT)) + return pubs, connections + + +def relay(config: ConfigSettings) -> Configuration: + """ one pub to one sub through many relays """ + connections: ez.NetworkDefinition = [] + + relays = [LoadTestRelay(config.settings) for _ in range(config.n_clients)] + connections.append((config.source.OUTPUT, relays[0].INPUT)) + for from_relay, to_relay in zip(relays[:-1], relays[1:]): + connections.append((from_relay.OUTPUT, to_relay.INPUT)) + connections.append((relays[-1].OUTPUT, config.sink.INPUT)) + + return relays, connections + +CONFIGS: typing.Iterable[Configurator] = [fanin, fanout, relay] + +class Communication(enum.StrEnum): + LOCAL = "local" + SHM = "shm" + SHM_SPREAD = "shm_spread" + TCP = "tcp" + TCP_SPREAD = "tcp_spread" + +def perform_test( + n_clients: int, + duration: float, + msg_size: int, + buffers: int, + comms: Communication, + config: Configurator +) -> Metrics: + + settings = LoadTestSettings( + dynamic_size = int(msg_size), + duration = duration, + buffers = buffers, + force_tcp = (comms == Communication.TCP), + ) + + source = LoadTestSource(settings) + sink = LoadTestSink(settings) + + components: typing.Mapping[str, ez.Component] = dict( + SINK = sink, + ) + + clients, connections = config(ConfigSettings(n_clients, settings, source, sink)) + + # The 'sink' MUST remain in this process for us to pull its state. + process_components: typing.Iterable[ez.Component] = [] + if comms == Communication.LOCAL: + # Every component in the same process (this one) + components["SOURCE"] = source + for i, client in enumerate(clients): + components[f"CLIENT_{i+1}"] = client + + else: + + if comms in (Communication.SHM_SPREAD, Communication.TCP_SPREAD): + # Every component in its own process. + components["SOURCE"] = source + process_components.append(source) + for i, client in enumerate(clients): + components[f'CLIENT_{i+1}'] = client + process_components.append(client) + + else: + # All clients and the source in ONE other process. + collect_comps: typing.Mapping[str, ez.Component] = dict() + collect_comps["SOURCE"] = source + for i, client in enumerate(clients): + collect_comps[f"CLIENT_{i+1}"] = client + proc_collection = collect(components = collect_comps) + components["PROC"] = proc_collection + process_components = [proc_collection] + + ez.run( + components = components, + connections = connections, + process_components = process_components, + ) + + return calculate_metrics(sink) + + +def calculate_metrics(sink: LoadTestSink) -> Metrics: + + # Log some useful summary statistics + min_timestamp = min(timestamp for timestamp, _, _ in sink.STATE.received_data) + max_timestamp = max(timestamp for timestamp, _, _ in sink.STATE.received_data) + total_latency = abs( + sum( + receive_timestamp - send_timestamp + for send_timestamp, receive_timestamp, _ in sink.STATE.received_data + ) + ) + + counters = list(sorted(t[2] for t in sink.STATE.received_data)) + dropped_samples = sum( + [max((x1 - x0) - 1, 0) for x1, x0 in zip(counters[1:], counters[:-1])] + ) + + num_samples = len(sink.STATE.received_data) + ez.logger.info(f"Samples received: {num_samples}") + sample_rate = num_samples / (max_timestamp - min_timestamp) + ez.logger.info(f"Sample rate: {sample_rate} Hz") + latency_mean = total_latency / num_samples + ez.logger.info(f"Mean latency: {latency_mean} s") + ez.logger.info(f"Total latency: {total_latency} s") + + total_data = num_samples * sink.SETTINGS.dynamic_size + data_rate = total_data / (max_timestamp - min_timestamp) + ez.logger.info(f"Data rate: {data_rate * 1e-6} MB/s") + + if dropped_samples: + ez.logger.error( + f"Dropped samples: {dropped_samples} ({dropped_samples / (dropped_samples + num_samples)}%)", + ) + + return Metrics( + num_msgs = num_samples, + sample_rate = sample_rate, + latency_mean = latency_mean, + latency_total = total_latency, + data_rate = data_rate + ) + + +@dataclasses.dataclass +class TestParameters: + msg_size: int + n_clients: int + config: str + comms: str + duration: float + num_buffers: int + +def _git_commit() -> str: + try: + return subprocess.check_output( + ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL + ).decode().strip() + except: + return "unknown" + +def _git_branch() -> str: + try: + return subprocess.check_output( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], stderr=subprocess.DEVNULL + ).decode().strip() + except: + return "unknown" + +@dataclasses.dataclass +class TestEnvironmentInfo: + ezmsg_version: str = dataclasses.field(default_factory=lambda: ez.__version__) + numpy_version: str = dataclasses.field(default_factory=lambda: np.__version__) + python_version: str = dataclasses.field(default_factory=lambda: sys.version.replace("\n", " ")) + os: str = dataclasses.field(default_factory=lambda: platform.system()) + os_version: str = dataclasses.field(default_factory=lambda: platform.version()) + machine: str = dataclasses.field(default_factory=lambda: platform.machine()) + processor: str = dataclasses.field(default_factory=lambda: platform.processor()) + cpu_count_logical: int | None = dataclasses.field(default_factory=lambda: psutil.cpu_count(logical=True)) + cpu_count_physical: int | None = dataclasses.field(default_factory=lambda: psutil.cpu_count(logical=False)) + memory_gb: float = dataclasses.field(default_factory=lambda:round(psutil.virtual_memory().total / (1024**3), 2)) + start_time: str = dataclasses.field(default_factory=lambda: datetime.datetime.now().isoformat(timespec="seconds")) + git_commit: str = dataclasses.field(default_factory=_git_commit) + git_branch: str = dataclasses.field(default_factory=_git_branch) \ No newline at end of file diff --git a/src/ezmsg/util/perf_test.py b/src/ezmsg/util/perf_test.py deleted file mode 100644 index b7536b5e..00000000 --- a/src/ezmsg/util/perf_test.py +++ /dev/null @@ -1,220 +0,0 @@ -import asyncio -import dataclasses -import datetime -import os -import platform -import time - -from typing import List, Tuple, AsyncGenerator - -import ezmsg.core as ez - -# We expect this test to generate LOTS of backpressure warnings -# PERF_LOGLEVEL = os.environ.get("EZMSG_LOGLEVEL", "ERROR") -# ez.logger.setLevel(PERF_LOGLEVEL) - -PLATFORM = { - "Darwin": "mac", - "Linux": "linux", - "Windows": "win", -}[platform.system()] -SAMPLE_SUMMARY_DATASET_PREFIX = "sample_summary" -COUNT_DATASET_NAME = "count" - - -try: - import numpy as np -except ImportError: - ez.logger.error("This test requires Numpy to run.") - raise - - -def get_datestamp() -> str: - return datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - - -class LoadTestSettings(ez.Settings): - duration: float = 30.0 - dynamic_size: int = 8 - buffers: int = 32 - - -@dataclasses.dataclass -class LoadTestSample: - _timestamp: float - counter: int - dynamic_data: np.ndarray - - -class LoadTestPublisher(ez.Unit): - OUTPUT = ez.OutputStream(LoadTestSample) - SETTINGS = LoadTestSettings - - async def initialize(self) -> None: - self.running = True - self.counter = 0 - self.OUTPUT.num_buffers = self.SETTINGS.buffers - - @ez.publisher(OUTPUT) - async def publish(self) -> AsyncGenerator: - ez.logger.info(f"Load test publisher started. (PID: {os.getpid()})") - start_time = time.time() - while self.running: - current_time = time.time() - if current_time - start_time >= self.SETTINGS.duration: - break - - yield ( - self.OUTPUT, - LoadTestSample( - _timestamp=time.time(), - counter=self.counter, - dynamic_data=np.zeros( - int(self.SETTINGS.dynamic_size // 8), dtype=np.float32 - ), - ), - ) - self.counter += 1 - ez.logger.info("Exiting publish") - raise ez.Complete - - async def shutdown(self) -> None: - self.running = False - ez.logger.info(f"Samples sent: {self.counter}") - - -class LoadTestSubscriberState(ez.State): - # Tuples of sent timestamp, received timestamp, counter, dynamic size - received_data: List[Tuple[float, float, int]] = dataclasses.field( - default_factory=list - ) - counter: int = -1 - - -class LoadTestSubscriber(ez.Unit): - INPUT = ez.InputStream(LoadTestSample) - SETTINGS = LoadTestSettings - STATE = LoadTestSubscriberState - - @ez.subscriber(INPUT, zero_copy=True) - async def receive(self, sample: LoadTestSample) -> None: - if sample.counter != self.STATE.counter + 1: - ez.logger.warning( - f"{sample.counter - self.STATE.counter - 1} samples skipped!" - ) - self.STATE.received_data.append( - (sample._timestamp, time.time(), sample.counter) - ) - self.STATE.counter = sample.counter - - @ez.task - async def log_result(self) -> None: - ez.logger.info(f"Load test subscriber started. (PID: {os.getpid()})") - - # Wait for the duration of the load test - await asyncio.sleep(self.SETTINGS.duration) - # logger.info(f"STATE = {self.STATE.received_data}") - - # Log some useful summary statistics - min_timestamp = min(timestamp for timestamp, _, _ in self.STATE.received_data) - max_timestamp = max(timestamp for timestamp, _, _ in self.STATE.received_data) - total_latency = abs( - sum( - receive_timestamp - send_timestamp - for send_timestamp, receive_timestamp, _ in self.STATE.received_data - ) - ) - - counters = list(sorted(t[2] for t in self.STATE.received_data)) - dropped_samples = sum( - [(x1 - x0) - 1 for x1, x0 in zip(counters[1:], counters[:-1])] - ) - - num_samples = len(self.STATE.received_data) - ez.logger.info(f"Samples received: {num_samples}") - ez.logger.info( - f"Sample rate: {num_samples / (max_timestamp - min_timestamp)} Hz" - ) - ez.logger.info(f"Mean latency: {total_latency / num_samples} s") - ez.logger.info(f"Total latency: {total_latency} s") - - total_data = num_samples * self.SETTINGS.dynamic_size - ez.logger.info( - f"Data rate: {total_data / (max_timestamp - min_timestamp) * 1e-6} MB/s" - ) - ez.logger.info( - f"Dropped samples: {dropped_samples} ({dropped_samples / (dropped_samples + num_samples)}%)", - ) - - raise ez.NormalTermination - - -class LoadTest(ez.Collection): - SETTINGS = LoadTestSettings - - PUBLISHER = LoadTestPublisher() - SUBSCRIBER = LoadTestSubscriber() - - def configure(self) -> None: - self.PUBLISHER.apply_settings(self.SETTINGS) - self.SUBSCRIBER.apply_settings(self.SETTINGS) - - def network(self) -> ez.NetworkDefinition: - return ((self.PUBLISHER.OUTPUT, self.SUBSCRIBER.INPUT),) - - def process_components(self): - return ( - self.PUBLISHER, - self.SUBSCRIBER, - ) - - -def get_time() -> float: - # time.perf_counter() isn't system-wide on Windows Python 3.6: - # https://bugs.python.org/issue37205 - return time.time() if PLATFORM == "win" else time.perf_counter() - - -def test_performance(duration, size, buffers) -> None: - ez.logger.info(f"Running load test for dynamic size: {size} bytes") - system = LoadTest( - LoadTestSettings(dynamic_size=int(size), duration=duration, buffers=buffers) - ) - ez.run(SYSTEM=system) - - -def run_many_dynamic_sizes(duration, buffers) -> None: - for exp in range(5, 22, 4): - test_performance(duration, 2**exp, buffers) - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument( - "--many-dynamic-sizes", - action="store_true", - help="Run load test for many dynamic sizes", - ) - parser.add_argument( - "--duration", - type=int, - default=2, - help="How long to run the load test (seconds)", - ) - parser.add_argument( - "--num-buffers", type=int, default=32, help="Shared memory buffers" - ) - - class Args: - many_dynamic_sizes: bool - duration: int - num_buffers: int - - args = parser.parse_args(namespace=Args) - - if args.many_dynamic_sizes: - run_many_dynamic_sizes(args.duration, args.num_buffers) - else: - test_performance(args.duration, 8, args.num_buffers) From bdd405275e1e264516a138f87dea0f2e9b2fe2d1 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Sat, 20 Sep 2025 08:50:38 -0400 Subject: [PATCH 42/88] less aggressive test strategy --- src/ezmsg/util/perf/eval.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ezmsg/util/perf/eval.py b/src/ezmsg/util/perf/eval.py index 894c45db..f7f7e27b 100644 --- a/src/ezmsg/util/perf/eval.py +++ b/src/ezmsg/util/perf/eval.py @@ -25,8 +25,8 @@ class PerfEvalArgs: def perf_eval(args: PerfEvalArgs) -> None: - msg_sizes = [2 ** exp for exp in range(4, 25, 4)] - n_clients = [2 ** exp for exp in range(0, 6)] + msg_sizes = [2 ** exp for exp in range(4, 25, 8)] + n_clients = [2 ** exp for exp in range(0, 6, 2)] comms = [c for c in Communication] test_list = list(itertools.product(msg_sizes, n_clients, CONFIGS, comms)) From ae0fb5dfca8924d6816d6d676b88a261de42c111 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Sat, 20 Sep 2025 10:00:16 -0400 Subject: [PATCH 43/88] slight refactor --- src/ezmsg/util/perf/analysis.py | 60 ++++++++--------- src/ezmsg/util/perf/command.py | 53 ++------------- src/ezmsg/util/perf/envinfo.py | 83 ++++++++++++++++++++++++ src/ezmsg/util/perf/{util.py => impl.py} | 41 ------------ src/ezmsg/util/perf/{eval.py => run.py} | 42 ++++++------ 5 files changed, 139 insertions(+), 140 deletions(-) create mode 100644 src/ezmsg/util/perf/envinfo.py rename src/ezmsg/util/perf/{util.py => impl.py} (84%) rename src/ezmsg/util/perf/{eval.py => run.py} (72%) diff --git a/src/ezmsg/util/perf/analysis.py b/src/ezmsg/util/perf/analysis.py index 48dd800e..0be97744 100644 --- a/src/ezmsg/util/perf/analysis.py +++ b/src/ezmsg/util/perf/analysis.py @@ -1,12 +1,13 @@ import json import typing import dataclasses +import argparse from pathlib import Path from ..messagecodec import MessageDecoder -from .util import ( - TestEnvironmentInfo, +from .envinfo import TestEnvironmentInfo +from .impl import ( TestParameters, Metrics, ) @@ -72,42 +73,41 @@ def load_perf(perf: Path) -> xr.Dataset: dataset = xr.Dataset(data_vars, attrs = dict(info = info)) return dataset -@dataclasses.dataclass -class SummaryArgs: - perf: Path - baseline: Path | None -def summary(args: SummaryArgs): +def summary(perf_path: Path, baseline_path: Path | None) -> None: + """ print perf test results and comparisons to the console """ - perf_dataset = load_perf(args.perf) + perf = load_perf(perf_path) + info = perf.attrs['info'] + if baseline_path is not None: + baseline = load_perf(baseline_path) + perf = (perf / baseline) * 100.0 - for config, config_ds in perf_dataset.groupby('config'): - for comms, comms_ds in config_ds.groupby('comms'): - print(f'{config}: {comms}') - print(comms_ds.sample_rate.data) + print(info) - if args.baseline is not None: - baseline_dataset = load_perf(args.baseline) + for _, config_ds in perf.groupby('config'): + for _, comms_ds in config_ds.groupby('comms'): + print(comms_ds.squeeze().to_dataframe()) + print("\n") + print("\n") -def command() -> None: - import argparse - - parser = argparse.ArgumentParser() - - parser.add_argument( +def setup_summary_cmdline(subparsers: argparse._SubParsersAction) -> None: + p_summary = subparsers.add_parser("summary", help = "summarize performance results") + p_summary.add_argument( "perf", - type=lambda x: Path(x), + type=Path, help="perf test", ) - - parser.add_argument( - "--baseline", "-b", - type=lambda x: Path(x), - default = None, - help="baseline perf test for comparison" + p_summary.add_argument( + "--baseline", + "-b", + type=Path, + default=None, + help="baseline perf test for comparison", ) - args = parser.parse_args(namespace=SummaryArgs) - - summary(args) + p_summary.set_defaults(_handler=lambda ns: summary( + perf_path = ns.perf, + baseline_path = ns.baseline + )) \ No newline at end of file diff --git a/src/ezmsg/util/perf/command.py b/src/ezmsg/util/perf/command.py index 5269a1d0..a100863d 100644 --- a/src/ezmsg/util/perf/command.py +++ b/src/ezmsg/util/perf/command.py @@ -1,55 +1,14 @@ -from pathlib import Path +import argparse -from .analysis import summary, SummaryArgs -from .eval import perf_eval, PerfEvalArgs +from .analysis import setup_summary_cmdline +from .run import setup_run_cmdline def command() -> None: - import argparse - parser = argparse.ArgumentParser(description = 'ezmsg perf test utility') subparsers = parser.add_subparsers(dest="command", required=True) - p_run = subparsers.add_parser("run", help="run performance test") - p_run.add_argument( - "--duration", - type=float, - default=2.0, - help="individual test duration in seconds (default = 2.0)", - ) - p_run.add_argument( - "--num-buffers", - type=int, - default=32, - help="shared memory buffers (default = 32)", - ) - - p_run.set_defaults(_handler=lambda ns: perf_eval( - PerfEvalArgs( - duration = ns.duration, - num_buffers = ns.num_buffers - ) - )) - - p_summary = subparsers.add_parser("summary", help = "summarise performance results") - p_summary.add_argument( - "perf", - type=Path, - help="perf test", - ) - p_summary.add_argument( - "--baseline", - "-b", - type=Path, - default=None, - help="baseline perf test for comparison", - ) - - p_summary.set_defaults(_handler=lambda ns: summary( - SummaryArgs( - perf = ns.perf, - baseline = ns.baseline - ) - )) - + setup_run_cmdline(subparsers) + setup_summary_cmdline(subparsers) + ns = parser.parse_args() ns._handler(ns) diff --git a/src/ezmsg/util/perf/envinfo.py b/src/ezmsg/util/perf/envinfo.py new file mode 100644 index 00000000..66fff18b --- /dev/null +++ b/src/ezmsg/util/perf/envinfo.py @@ -0,0 +1,83 @@ +import dataclasses +import datetime +import platform +import typing +import sys +import subprocess + +import ezmsg.core as ez + +try: + import numpy as np +except ImportError: + ez.logger.error("ezmsg perf requires numpy") + raise + +try: + import psutil +except ImportError: + ez.logger.error("ezmsg perf requires psutil") + raise + + +def _git_commit() -> str: + try: + return subprocess.check_output( + ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL + ).decode().strip() + except: + return "unknown" + +def _git_branch() -> str: + try: + return subprocess.check_output( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], stderr=subprocess.DEVNULL + ).decode().strip() + except: + return "unknown" + +@dataclasses.dataclass +class TestEnvironmentInfo: + ezmsg_version: str = dataclasses.field(default_factory=lambda: ez.__version__) + numpy_version: str = dataclasses.field(default_factory=lambda: np.__version__) + python_version: str = dataclasses.field(default_factory=lambda: sys.version.replace("\n", " ")) + os: str = dataclasses.field(default_factory=lambda: platform.system()) + os_version: str = dataclasses.field(default_factory=lambda: platform.version()) + machine: str = dataclasses.field(default_factory=lambda: platform.machine()) + processor: str = dataclasses.field(default_factory=lambda: platform.processor()) + cpu_count_logical: int | None = dataclasses.field(default_factory=lambda: psutil.cpu_count(logical=True)) + cpu_count_physical: int | None = dataclasses.field(default_factory=lambda: psutil.cpu_count(logical=False)) + memory_gb: float = dataclasses.field(default_factory=lambda:round(psutil.virtual_memory().total / (1024**3), 2)) + start_time: str = dataclasses.field(default_factory=lambda: datetime.datetime.now().isoformat(timespec="seconds")) + git_commit: str = dataclasses.field(default_factory=_git_commit) + git_branch: str = dataclasses.field(default_factory=_git_branch) + + def __str__(self) -> str: + fields = dataclasses.asdict(self) + width = max(len(k) for k in fields) + lines = ["TestEnvironmentInfo:"] + for key, value in fields.items(): + lines.append(f" {key.ljust(width)} : {value}") + return "\n".join(lines) + + def diff(self, other: "TestEnvironmentInfo") -> typing.Dict[str, typing.Tuple[typing.Any, typing.Any]]: + """Return a structured diff: {field: (self_value, other_value)} for changed fields.""" + a = dataclasses.asdict(self) + b = dataclasses.asdict(other) + keys = set(a) | set(b) + return {k: (a.get(k), b.get(k)) for k in keys if a.get(k) != b.get(k)} + + +def format_env_diff(diffs: typing.Dict[str, typing.Tuple[typing.Any, typing.Any]]) -> str: + """Pretty-print the structured diff in the same aligned style.""" + if not diffs: + return "No differences." + width = max(len(k) for k in diffs) + lines = ["Differences in TestEnvironmentInfo:"] + for k in sorted(diffs): + left, right = diffs[k] + lines.append(f" {k.ljust(width)} : {left} != {right}") + return "\n".join(lines) + +def diff_envs(a: TestEnvironmentInfo, b: TestEnvironmentInfo) -> str: + return format_env_diff(a.diff(b)) \ No newline at end of file diff --git a/src/ezmsg/util/perf/util.py b/src/ezmsg/util/perf/impl.py similarity index 84% rename from src/ezmsg/util/perf/util.py rename to src/ezmsg/util/perf/impl.py index 8883591b..58357b8c 100644 --- a/src/ezmsg/util/perf/util.py +++ b/src/ezmsg/util/perf/impl.py @@ -1,13 +1,9 @@ import asyncio import dataclasses -import datetime import os -import platform import time import typing import enum -import sys -import subprocess import ezmsg.core as ez @@ -17,12 +13,6 @@ ez.logger.error("ezmsg perf requires numpy") raise -try: - import psutil -except ImportError: - ez.logger.error("ezmsg perf requires psutil") - raise - def collect( components: typing.Optional[typing.Mapping[str, ez.Component]] = None, @@ -323,34 +313,3 @@ class TestParameters: duration: float num_buffers: int -def _git_commit() -> str: - try: - return subprocess.check_output( - ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL - ).decode().strip() - except: - return "unknown" - -def _git_branch() -> str: - try: - return subprocess.check_output( - ["git", "rev-parse", "--abbrev-ref", "HEAD"], stderr=subprocess.DEVNULL - ).decode().strip() - except: - return "unknown" - -@dataclasses.dataclass -class TestEnvironmentInfo: - ezmsg_version: str = dataclasses.field(default_factory=lambda: ez.__version__) - numpy_version: str = dataclasses.field(default_factory=lambda: np.__version__) - python_version: str = dataclasses.field(default_factory=lambda: sys.version.replace("\n", " ")) - os: str = dataclasses.field(default_factory=lambda: platform.system()) - os_version: str = dataclasses.field(default_factory=lambda: platform.version()) - machine: str = dataclasses.field(default_factory=lambda: platform.machine()) - processor: str = dataclasses.field(default_factory=lambda: platform.processor()) - cpu_count_logical: int | None = dataclasses.field(default_factory=lambda: psutil.cpu_count(logical=True)) - cpu_count_physical: int | None = dataclasses.field(default_factory=lambda: psutil.cpu_count(logical=False)) - memory_gb: float = dataclasses.field(default_factory=lambda:round(psutil.virtual_memory().total / (1024**3), 2)) - start_time: str = dataclasses.field(default_factory=lambda: datetime.datetime.now().isoformat(timespec="seconds")) - git_commit: str = dataclasses.field(default_factory=_git_commit) - git_branch: str = dataclasses.field(default_factory=_git_branch) \ No newline at end of file diff --git a/src/ezmsg/util/perf/eval.py b/src/ezmsg/util/perf/run.py similarity index 72% rename from src/ezmsg/util/perf/eval.py rename to src/ezmsg/util/perf/run.py index f7f7e27b..d36dd02c 100644 --- a/src/ezmsg/util/perf/eval.py +++ b/src/ezmsg/util/perf/run.py @@ -1,12 +1,13 @@ import json import datetime import itertools +import argparse from dataclasses import dataclass from ..messagecodec import MessageEncoder -from .util import ( - TestEnvironmentInfo, +from .envinfo import TestEnvironmentInfo +from .impl import ( TestParameters, perform_test, Communication, @@ -19,11 +20,11 @@ def get_datestamp() -> str: return datetime.datetime.now().strftime("%Y%m%d_%H%M%S") @dataclass -class PerfEvalArgs: +class PerfRunArgs: duration: float num_buffers: int -def perf_eval(args: PerfEvalArgs) -> None: +def perf_run(args: PerfRunArgs) -> None: msg_sizes = [2 ** exp for exp in range(4, 25, 8)] n_clients = [2 ** exp for exp in range(0, 6, 2)] @@ -64,28 +65,25 @@ def perf_eval(args: PerfEvalArgs) -> None: out_f.write(json.dumps(output, cls = MessageEncoder) + "\n") +def setup_run_cmdline(subparsers: argparse._SubParsersAction) -> None: -def command() -> None: - import argparse - - parser = argparse.ArgumentParser() - - parser.add_argument( + p_run = subparsers.add_parser("run", help="run performance test") + p_run.add_argument( "--duration", type=float, default=2.0, - help="How long to run each load test (seconds) (default = 2.0)", + help="individual test duration in seconds (default = 2.0)", ) - - parser.add_argument( - "--num-buffers", - type=int, - default=32, - help="shared memory buffers (default = 32)" + p_run.add_argument( + "--num-buffers", + type=int, + default=32, + help="shared memory buffers (default = 32)", ) - args = parser.parse_args(namespace=PerfEvalArgs) - - perf_eval(args) - - + p_run.set_defaults(_handler=lambda ns: perf_run( + PerfRunArgs( + duration = ns.duration, + num_buffers = ns.num_buffers + ) + )) \ No newline at end of file From e25284952b30e5e76fbd352e21cd174fdc9b1241 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Sat, 20 Sep 2025 13:28:20 -0400 Subject: [PATCH 44/88] minor --- src/ezmsg/util/perf/ai.py | 111 ++++++++++++++++++++++++++++++++ src/ezmsg/util/perf/analysis.py | 38 ++++++++--- 2 files changed, 141 insertions(+), 8 deletions(-) create mode 100644 src/ezmsg/util/perf/ai.py diff --git a/src/ezmsg/util/perf/ai.py b/src/ezmsg/util/perf/ai.py new file mode 100644 index 00000000..a212b81c --- /dev/null +++ b/src/ezmsg/util/perf/ai.py @@ -0,0 +1,111 @@ +import os +import json +import textwrap +import urllib.request + +DEFAULT_TEST_DESCRIPTION = """\ +You are analyzing performance test results for the ezmsg pub/sub system. + +Configurations (config): +- fanin: Many publishers to one subscriber +- fanout: one publisher to many subscribers +- relay: one publisher to one subscriber through many relays + +Communication strategies (comms): +- local: all subs, relays, and pubs are in the SAME process +- shm / tcp: some clients move to a second process; comms via shared memory / TCP + * fanin: all publishers moved + * fanout: all subscribers moved + * relay: the publisher and all relay nodes moved +- shm_spread / tcp_spread: each client in its own process; comms via SHM / TCP respectively + +Variables: +- n_clients: pubs (fanin), subs (fanout), or relays (relay) +- msg_size: nominal message size (bytes) + +Metrics: +- sample_rate: messages/sec at the sink (higher = better) +- data_rate: bytes/sec at the sink (higher = better) +- latency_mean: average send -> receive latency in seconds (lower = better) + +Task: +Summarize performance test results. Explain trade-offs by comms/config, call out anomalies/outliers. +Keep the tone concise, technical, and actionable. +Please format output such that it will display nicely in a terminal output. + +If the rest results are a "PERFORMANCE COMPARISON": +- Metrics are in percentages (100.0 = 100 percent = no change) and do not reflect the ground-truth physical units. +- Summarize key improvements/regressions +- Performance differences +/- 5 percent are likely in the noise. +""" + +def chatgpt_analyze_results( + results_text: str, + *, + prompt: str | None = None, + model: str | None = None, + max_chars: int = 120_000, + temperature: float = 0.2, +) -> str: + """ + Send results + a test description to OpenAI's Responses API and print the analysis. + + Env vars: + - OPENAI_API_KEY (required) + - OPENAI_MODEL (optional; e.g., 'gpt-4o-mini' or a newer model) + """ + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + raise RuntimeError("Please set OPENAI_API_KEY in your environment.") + + model = model or os.getenv("OPENAI_MODEL", "gpt-4o-mini") + + # Keep requests reasonable in size + results_snippet = results_text if len(results_text) <= max_chars else ( + results_text[:max_chars] + "\n\n[...truncated for token budget...]" + ) + + # You can tweak the system instruction here to steer tone/format + system_instruction = "You are a senior performance engineer. Prefer precise, structured analysis." + + user_payload = textwrap.dedent(f"""\ + {prompt or DEFAULT_TEST_DESCRIPTION} + + === BEGIN RESULTS === + {results_snippet} + === END RESULTS === + """) + + body = { + "model": model, + "temperature": temperature, + "input": [ + {"role": "system", "content": [{"type": "text", "text": system_instruction}]}, + {"role": "user", "content": [{"type": "text", "text": user_payload}]}, + ], + } + + req = urllib.request.Request( + "https://api.openai.com/v1/responses", + data=json.dumps(body).encode("utf-8"), + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}", + }, + method="POST", + ) + + with urllib.request.urlopen(req) as resp: + data = json.load(resp) + + # Robust extraction: prefer output_text; fall back to concatenating output content + text = data.get("output_text") + if not text: + parts = [] + for item in data.get("output", []) or []: + for c in item.get("content", []) or []: + if c.get("type") in ("output_text", "text") and "text" in c: + parts.append(c["text"]) + text = "\n".join(parts) if parts else json.dumps(data, indent=2) + + return text diff --git a/src/ezmsg/util/perf/analysis.py b/src/ezmsg/util/perf/analysis.py index 0be97744..be0cb1e7 100644 --- a/src/ezmsg/util/perf/analysis.py +++ b/src/ezmsg/util/perf/analysis.py @@ -6,7 +6,8 @@ from pathlib import Path from ..messagecodec import MessageDecoder -from .envinfo import TestEnvironmentInfo +from .envinfo import TestEnvironmentInfo, format_env_diff +from .ai import chatgpt_analyze_results from .impl import ( TestParameters, Metrics, @@ -74,22 +75,37 @@ def load_perf(perf: Path) -> xr.Dataset: return dataset -def summary(perf_path: Path, baseline_path: Path | None) -> None: +def summary(perf_path: Path, baseline_path: Path | None, ai: bool = False) -> None: """ print perf test results and comparisons to the console """ + output = '' + perf = load_perf(perf_path) - info = perf.attrs['info'] + info: TestEnvironmentInfo = perf.attrs['info'] + output += str(info) + '\n\n' + + if baseline_path is not None: + output += "PERFORMANCE COMPARISON\n\n" baseline = load_perf(baseline_path) perf = (perf / baseline) * 100.0 + baseline_info: TestEnvironmentInfo = baseline.attrs['info'] + output += format_env_diff(info.diff(baseline_info)) + '\n\n' - print(info) + # These raw stats are still valuable to have, but are confusing + # when making relative comparisons + perf = perf.drop_vars(['latency_total', 'num_msgs']) for _, config_ds in perf.groupby('config'): for _, comms_ds in config_ds.groupby('comms'): - print(comms_ds.squeeze().to_dataframe()) - print("\n") - print("\n") + output += str(comms_ds.squeeze().to_dataframe()) + '\n\n' + output += '\n' + + print(output) + + if ai: + print('Querying ChatGPT for AI-assisted analysis of performance test results') + print(chatgpt_analyze_results(output)) def setup_summary_cmdline(subparsers: argparse._SubParsersAction) -> None: @@ -106,8 +122,14 @@ def setup_summary_cmdline(subparsers: argparse._SubParsersAction) -> None: default=None, help="baseline perf test for comparison", ) + p_summary.add_argument( + "--ai", + action="store_true", + help="ask chatgpt for an analysis of the results. requires OPENAI_API_KEY set in environment" + ) p_summary.set_defaults(_handler=lambda ns: summary( perf_path = ns.perf, - baseline_path = ns.baseline + baseline_path = ns.baseline, + ai = ns.ai )) \ No newline at end of file From 386461a1c6fa73928aa5923bbabc2d01b2dfb72e Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Sat, 20 Sep 2025 17:40:44 -0400 Subject: [PATCH 45/88] html analysis output; ai was a bust --- src/ezmsg/util/perf/ai.py | 111 ------------ src/ezmsg/util/perf/analysis.py | 300 ++++++++++++++++++++++++++++++-- src/ezmsg/util/perf/run.py | 23 +++ 3 files changed, 312 insertions(+), 122 deletions(-) delete mode 100644 src/ezmsg/util/perf/ai.py diff --git a/src/ezmsg/util/perf/ai.py b/src/ezmsg/util/perf/ai.py deleted file mode 100644 index a212b81c..00000000 --- a/src/ezmsg/util/perf/ai.py +++ /dev/null @@ -1,111 +0,0 @@ -import os -import json -import textwrap -import urllib.request - -DEFAULT_TEST_DESCRIPTION = """\ -You are analyzing performance test results for the ezmsg pub/sub system. - -Configurations (config): -- fanin: Many publishers to one subscriber -- fanout: one publisher to many subscribers -- relay: one publisher to one subscriber through many relays - -Communication strategies (comms): -- local: all subs, relays, and pubs are in the SAME process -- shm / tcp: some clients move to a second process; comms via shared memory / TCP - * fanin: all publishers moved - * fanout: all subscribers moved - * relay: the publisher and all relay nodes moved -- shm_spread / tcp_spread: each client in its own process; comms via SHM / TCP respectively - -Variables: -- n_clients: pubs (fanin), subs (fanout), or relays (relay) -- msg_size: nominal message size (bytes) - -Metrics: -- sample_rate: messages/sec at the sink (higher = better) -- data_rate: bytes/sec at the sink (higher = better) -- latency_mean: average send -> receive latency in seconds (lower = better) - -Task: -Summarize performance test results. Explain trade-offs by comms/config, call out anomalies/outliers. -Keep the tone concise, technical, and actionable. -Please format output such that it will display nicely in a terminal output. - -If the rest results are a "PERFORMANCE COMPARISON": -- Metrics are in percentages (100.0 = 100 percent = no change) and do not reflect the ground-truth physical units. -- Summarize key improvements/regressions -- Performance differences +/- 5 percent are likely in the noise. -""" - -def chatgpt_analyze_results( - results_text: str, - *, - prompt: str | None = None, - model: str | None = None, - max_chars: int = 120_000, - temperature: float = 0.2, -) -> str: - """ - Send results + a test description to OpenAI's Responses API and print the analysis. - - Env vars: - - OPENAI_API_KEY (required) - - OPENAI_MODEL (optional; e.g., 'gpt-4o-mini' or a newer model) - """ - api_key = os.getenv("OPENAI_API_KEY") - if not api_key: - raise RuntimeError("Please set OPENAI_API_KEY in your environment.") - - model = model or os.getenv("OPENAI_MODEL", "gpt-4o-mini") - - # Keep requests reasonable in size - results_snippet = results_text if len(results_text) <= max_chars else ( - results_text[:max_chars] + "\n\n[...truncated for token budget...]" - ) - - # You can tweak the system instruction here to steer tone/format - system_instruction = "You are a senior performance engineer. Prefer precise, structured analysis." - - user_payload = textwrap.dedent(f"""\ - {prompt or DEFAULT_TEST_DESCRIPTION} - - === BEGIN RESULTS === - {results_snippet} - === END RESULTS === - """) - - body = { - "model": model, - "temperature": temperature, - "input": [ - {"role": "system", "content": [{"type": "text", "text": system_instruction}]}, - {"role": "user", "content": [{"type": "text", "text": user_payload}]}, - ], - } - - req = urllib.request.Request( - "https://api.openai.com/v1/responses", - data=json.dumps(body).encode("utf-8"), - headers={ - "Content-Type": "application/json", - "Authorization": f"Bearer {api_key}", - }, - method="POST", - ) - - with urllib.request.urlopen(req) as resp: - data = json.load(resp) - - # Robust extraction: prefer output_text; fall back to concatenating output content - text = data.get("output_text") - if not text: - parts = [] - for item in data.get("output", []) or []: - for c in item.get("content", []) or []: - if c.get("type") in ("output_text", "text") and "text" in c: - parts.append(c["text"]) - text = "\n".join(parts) if parts else json.dumps(data, indent=2) - - return text diff --git a/src/ezmsg/util/perf/analysis.py b/src/ezmsg/util/perf/analysis.py index be0cb1e7..2ddc7062 100644 --- a/src/ezmsg/util/perf/analysis.py +++ b/src/ezmsg/util/perf/analysis.py @@ -2,12 +2,16 @@ import typing import dataclasses import argparse +import html +import math +import webbrowser +import tempfile from pathlib import Path from ..messagecodec import MessageDecoder from .envinfo import TestEnvironmentInfo, format_env_diff -from .ai import chatgpt_analyze_results +from .run import get_datestamp from .impl import ( TestParameters, Metrics, @@ -17,6 +21,7 @@ try: import xarray as xr + import pandas as pd # xarray depends on pandas except ImportError: ez.logger.error('ezmsg perf analysis requires xarray') raise @@ -74,8 +79,186 @@ def load_perf(perf: Path) -> xr.Dataset: dataset = xr.Dataset(data_vars, attrs = dict(info = info)) return dataset +NOISE_BAND_PCT = 5.0 # +/-5% is "in the noise" for comparisons -def summary(perf_path: Path, baseline_path: Path | None, ai: bool = False) -> None: +def _escape(s: str) -> str: + return html.escape(str(s), quote=True) + +def _env_block(title: str, body: str) -> str: + return f""" +
+

{_escape(title)}

+
{_escape(body).strip()}
+
+ """ + +def _legend_block() -> str: + return """ +
+

Legend

+
    +
  • Comparison mode: values are percentages (100 = no change).
  • +
  • Noise band: ±5% considered negligible (no color).
  • +
  • Green: improvement (↑ sample/data rate, ↓ latency).
  • +
  • Red: regression (↓ sample/data rate, ↑ latency).
  • +
+
+ """ + +def _base_css() -> str: + # Minimal, print-friendly CSS + color scales for cells. + return """ + + """ + +def _color_for_comparison(value: float, metric: str) -> str: + """ + Returns inline CSS background for a comparison % value. + value: e.g., 97.3, 104.8, etc. + For sample_rate/data_rate: improvement > 100 (good). + For latency_mean: improvement < 100 (good). + Noise band ±5% around 100 is neutral. + """ + if not (isinstance(value, (int, float)) and math.isfinite(value)): + return "" + + delta = value - 100.0 + # Determine direction: + is good for sample/data; - is good for latency + if metric in ("sample_rate", "data_rate"): + # positive delta good, negative bad + magnitude = abs(delta) + sign_good = delta > 0 + elif metric == "latency_mean": + # negative delta good (lower latency) + magnitude = abs(delta) + sign_good = delta < 0 + else: + return "" + + # Noise band: keep neutral + if magnitude <= NOISE_BAND_PCT: + return "" + + # Scale 5%..50% across 0..1; clamp + scale = max(0.0, min(1.0, (magnitude - NOISE_BAND_PCT) / 45.0)) + + # Choose hue and lightness; use HSL with gentle saturation + hue = "var(--green)" if sign_good else "var(--red)" + # opacity via alpha blend on lightness via HSLa + # Use saturation ~70%, lightness around 40–50% blended with table bg + alpha = 0.15 + 0.35 * scale # 0.15..0.50 + return f"background-color: hsla({hue}, 70%, 45%, {alpha});" + +def _format_number(x) -> str: + if isinstance(x, (int,)) and not isinstance(x, bool): + return f"{x:d}" + try: + xf = float(x) + except Exception: + return _escape(str(x)) + # Heuristic: for comparison percentages, 1 decimal is nice; for absolute, 3 decimals for latency. + return f"{xf:.3f}" + + +def summary(perf_path: Path, baseline_path: Path | None, html: bool = False) -> None: """ print perf test results and comparisons to the console """ output = '' @@ -84,17 +267,21 @@ def summary(perf_path: Path, baseline_path: Path | None, ai: bool = False) -> No info: TestEnvironmentInfo = perf.attrs['info'] output += str(info) + '\n\n' - + relative = False + env_diff = None if baseline_path is not None: + relative = True output += "PERFORMANCE COMPARISON\n\n" baseline = load_perf(baseline_path) perf = (perf / baseline) * 100.0 baseline_info: TestEnvironmentInfo = baseline.attrs['info'] - output += format_env_diff(info.diff(baseline_info)) + '\n\n' + env_diff = format_env_diff(info.diff(baseline_info)) + output += env_diff + '\n\n' # These raw stats are still valuable to have, but are confusing # when making relative comparisons perf = perf.drop_vars(['latency_total', 'num_msgs']) + df = perf.squeeze().to_dataframe() for _, config_ds in perf.groupby('config'): for _, comms_ds in config_ds.groupby('comms'): @@ -103,9 +290,100 @@ def summary(perf_path: Path, baseline_path: Path | None, ai: bool = False) -> No print(output) - if ai: - print('Querying ChatGPT for AI-assisted analysis of performance test results') - print(chatgpt_analyze_results(output)) + if html: + # Ensure expected columns exist + expected_cols = {"sample_rate", "data_rate", "latency_mean"} + missing = expected_cols - set(df.columns) + if missing: + raise ValueError(f"Missing expected columns in dataset: {missing}") + + # We'll render a table per (config, comms) group. + groups = df.reset_index().sort_values( + by=["config", "comms", "n_clients", "msg_size"] + ).groupby(["config", "comms"], sort=False) + + # Build HTML + parts: list[str] = [] + parts.append("") + parts.append("") + parts.append("ezmsg perf report") + parts.append(_base_css()) + parts.append("
") + + parts.append("
") + parts.append("

ezmsg Performance Report

") + sub = str(perf_path) + if baseline_path is not None: + sub += f' relative to {str(baseline_path)}' + parts.append(f"
{_escape(sub)}
") + parts.append("
") + + if info is not None: + parts.append(_env_block("Test Environment", str(info))) + + if env_diff is not None: + # Show diffs using your helper + parts.append("
") + parts.append("

Environment Differences vs Baseline

") + parts.append(f"
{_escape(env_diff)}
") + parts.append("
") + parts.append(_legend_block()) + + # Render each group + for (config, comms), g in groups: + # Keep only expected columns in order + cols = ["n_clients", "msg_size", "sample_rate", "data_rate", "latency_mean"] + g = g[cols].copy() + + # String format some columns (msg_size with separators) + g["msg_size"] = g["msg_size"].map(lambda x: f"{int(x):,}" if pd.notna(x) else x) + + # Build table manually so we can inject inline cell styles easily + # (pandas Styler is great but produces bulky HTML; manual keeps it clean) + header = f""" + + + n_clients + msg_size {'' if relative else '(b)'} + sample_rate {'' if relative else '(msgs/s)'} + data_rate {'' if relative else '(MB/s)'} + latency_mean {'' if relative else '(us)'} + + + """ + body_rows: list[str] = [] + for _, row in g.iterrows(): + sr, dr, lt = row["sample_rate"], row["data_rate"], row["latency_mean"] + dr = dr if relative else dr / 2**20 + lt = lt if relative else lt * 1e6 + sr_style = _color_for_comparison(sr, "sample_rate") if relative else "" + dr_style = _color_for_comparison(dr, "data_rate") if relative else "" + lt_style = _color_for_comparison(lt, "latency_mean") if relative else "" + + body_rows.append( + "" + f"{_format_number(row['n_clients'])}" + f"{_escape(row['msg_size'])}" + f"{_format_number(sr)}" + f"{_format_number(dr)}" + f"{_format_number(lt)}" + "" + ) + table_html = f"{header}{''.join(body_rows)}
" + + parts.append( + f"

" + f"{_escape(config)}" + f"{_escape(comms)}" + f"

{table_html}
" + ) + + parts.append("
") + html_text = "".join(parts) + + out_path = Path(f'report_{get_datestamp()}.html') + out_path.write_text(html_text, encoding="utf-8") + webbrowser.open(out_path.resolve().as_uri()) def setup_summary_cmdline(subparsers: argparse._SubParsersAction) -> None: @@ -123,13 +401,13 @@ def setup_summary_cmdline(subparsers: argparse._SubParsersAction) -> None: help="baseline perf test for comparison", ) p_summary.add_argument( - "--ai", - action="store_true", - help="ask chatgpt for an analysis of the results. requires OPENAI_API_KEY set in environment" + "--html", + action = 'store_true', + help = "generate an html output file and render results in browser", ) p_summary.set_defaults(_handler=lambda ns: summary( perf_path = ns.perf, baseline_path = ns.baseline, - ai = ns.ai + html = ns.html )) \ No newline at end of file diff --git a/src/ezmsg/util/perf/run.py b/src/ezmsg/util/perf/run.py index d36dd02c..1856f2a6 100644 --- a/src/ezmsg/util/perf/run.py +++ b/src/ezmsg/util/perf/run.py @@ -25,6 +25,29 @@ class PerfRunArgs: num_buffers: int def perf_run(args: PerfRunArgs) -> None: + """ + Configurations (config): + - fanin: Many publishers to one subscriber + - fanout: one publisher to many subscribers + - relay: one publisher to one subscriber through many relays + + Communication strategies (comms): + - local: all subs, relays, and pubs are in the SAME process + - shm / tcp: some clients move to a second process; comms via shared memory / TCP + * fanin: all publishers moved + * fanout: all subscribers moved + * relay: the publisher and all relay nodes moved + - shm_spread / tcp_spread: each client in its own process; comms via SHM / TCP respectively + + Variables: + - n_clients: pubs (fanin), subs (fanout), or relays (relay) + - msg_size: nominal message size (bytes) + + Metrics: + - sample_rate: messages/sec at the sink (higher = better) + - data_rate: bytes/sec at the sink (higher = better) + - latency_mean: average send -> receive latency in seconds (lower = better) + """ msg_sizes = [2 ** exp for exp in range(4, 25, 8)] n_clients = [2 ** exp for exp in range(0, 6, 2)] From d01deb7845bbf44dabcbd459a9cbd5fb40b10cc5 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Sat, 20 Sep 2025 17:49:44 -0400 Subject: [PATCH 46/88] report tweaks --- src/ezmsg/util/perf/analysis.py | 27 ++++++++++++++++++++++++++- src/ezmsg/util/perf/run.py | 24 ------------------------ 2 files changed, 26 insertions(+), 25 deletions(-) diff --git a/src/ezmsg/util/perf/analysis.py b/src/ezmsg/util/perf/analysis.py index 2ddc7062..9d33a068 100644 --- a/src/ezmsg/util/perf/analysis.py +++ b/src/ezmsg/util/perf/analysis.py @@ -5,7 +5,6 @@ import html import math import webbrowser -import tempfile from pathlib import Path @@ -32,6 +31,30 @@ ez.logger.error('ezmsg perf analysis requires numpy') raise +TEST_DESCRIPTION = """ +Configurations (config): +- fanin: many publishers to one subscriber +- fanout: one publisher to many subscribers +- relay: one publisher to one subscriber through many relays + +Communication strategies (comms): +- local: all subs, relays, and pubs are in the SAME process +- shm / tcp: some clients move to a second process; comms via shared memory / TCP + * fanin: all publishers moved + * fanout: all subscribers moved + * relay: the publisher and all relay nodes moved +- shm_spread / tcp_spread: each client in its own process; comms via SHM / TCP respectively + +Variables: +- n_clients: pubs (fanin), subs (fanout), or relays (relay) +- msg_size: nominal message size (bytes) + +Metrics: +- sample_rate: messages/sec at the sink (higher = better) +- data_rate: bytes/sec at the sink (higher = better) +- latency_mean: average send -> receive latency in seconds (lower = better) +""" + def load_perf(perf: Path) -> xr.Dataset: @@ -321,6 +344,8 @@ def summary(perf_path: Path, baseline_path: Path | None, html: bool = False) -> if info is not None: parts.append(_env_block("Test Environment", str(info))) + parts.append(_env_block("Test Details", TEST_DESCRIPTION)) + if env_diff is not None: # Show diffs using your helper parts.append("
") diff --git a/src/ezmsg/util/perf/run.py b/src/ezmsg/util/perf/run.py index 1856f2a6..d7280e51 100644 --- a/src/ezmsg/util/perf/run.py +++ b/src/ezmsg/util/perf/run.py @@ -25,30 +25,6 @@ class PerfRunArgs: num_buffers: int def perf_run(args: PerfRunArgs) -> None: - """ - Configurations (config): - - fanin: Many publishers to one subscriber - - fanout: one publisher to many subscribers - - relay: one publisher to one subscriber through many relays - - Communication strategies (comms): - - local: all subs, relays, and pubs are in the SAME process - - shm / tcp: some clients move to a second process; comms via shared memory / TCP - * fanin: all publishers moved - * fanout: all subscribers moved - * relay: the publisher and all relay nodes moved - - shm_spread / tcp_spread: each client in its own process; comms via SHM / TCP respectively - - Variables: - - n_clients: pubs (fanin), subs (fanout), or relays (relay) - - msg_size: nominal message size (bytes) - - Metrics: - - sample_rate: messages/sec at the sink (higher = better) - - data_rate: bytes/sec at the sink (higher = better) - - latency_mean: average send -> receive latency in seconds (lower = better) - """ - msg_sizes = [2 ** exp for exp in range(4, 25, 8)] n_clients = [2 ** exp for exp in range(0, 6, 2)] comms = [c for c in Communication] From 2792fa6fa616772b49ce3769b3d5fd985c7d3c7d Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Thu, 18 Sep 2025 14:02:37 -0400 Subject: [PATCH 47/88] local comms with no tcp ack --- src/ezmsg/core/messagechannel.py | 76 +++++++++++++++++++++++--------- src/ezmsg/core/pubclient.py | 27 ++++++++---- src/ezmsg/core/subclient.py | 4 +- 3 files changed, 76 insertions(+), 31 deletions(-) diff --git a/src/ezmsg/core/messagechannel.py b/src/ezmsg/core/messagechannel.py index c55b4b1d..a55e15ff 100644 --- a/src/ezmsg/core/messagechannel.py +++ b/src/ezmsg/core/messagechannel.py @@ -32,7 +32,7 @@ class CacheMiss(Exception): ... -class _Channel: +class Channel: """cache-backed message channel for a particular publisher""" id: UUID @@ -51,6 +51,7 @@ class _Channel: _pub_task: asyncio.Task[None] _pub_writer: asyncio.StreamWriter _graph_address: AddressType | None + _local_backpressure: Backpressure | None = None def __init__( self, @@ -58,7 +59,8 @@ def __init__( pub_id: UUID, num_buffers: int, shm: SHMContext | None, - graph_address: AddressType | None + graph_address: AddressType | None, + local_backpressure: Backpressure | None = None, ) -> None: self.id = id self.pub_id = pub_id @@ -70,13 +72,15 @@ def __init__( self.backpressure = Backpressure(self.num_buffers) self.clients = dict() self._graph_address = graph_address + self._local_backpressure = local_backpressure @classmethod async def create( cls, pub_id: UUID, graph_address: AddressType, - ) -> "_Channel": + local_backpressure: Backpressure | None = None + ) -> "": graph_service = GraphService(graph_address) graph_reader, graph_writer = await graph_service.open_connection() @@ -114,7 +118,7 @@ async def create( num_buffers = await read_int(reader) - chan = cls(UUID(id_str), pub_id, num_buffers, shm, graph_address) + chan = cls(UUID(id_str), pub_id, num_buffers, shm, graph_address, local_backpressure) chan._graph_task = asyncio.create_task( chan._graph_connection(graph_reader, graph_writer), @@ -213,21 +217,22 @@ def _notify_clients(self, msg_id: int) -> None: self.backpressure.lease(client_id, msg_id % self.num_buffers) queue.put_nowait((self.pub_id, msg_id)) - def put_local(self, msg_id: int, msg: typing.Any) -> bool: + def put_local(self, msg_id: int, msg: typing.Any) -> None: """ put an object into cache (should only be used by Publishers) returns true if any clients were notified """ - self._notify_clients(msg_id) - - # if buffer is still available after notify, no subs were notified here - buf_idx = msg_id % self.num_buffers - if self.backpressure.available(buf_idx): - return False + if self._local_backpressure is None: + raise ValueError('cannot put_local without access to publisher backpressure (is publisher in same process?)') + buf_idx = msg_id % self.num_buffers self.cache_id[buf_idx] = msg_id self.cache[buf_idx] = msg - return True + self._notify_clients(msg_id) + + # if buffer is not available after notify, subs were notified + if not self.backpressure.available(buf_idx): + self._local_backpressure.lease(self.id, buf_idx) @contextmanager def get(self, msg_id: int, client_id: UUID) -> typing.Generator[typing.Any, None, None]: @@ -269,6 +274,12 @@ def get(self, msg_id: int, client_id: UUID) -> typing.Generator[typing.Any, None self.backpressure.free(client_id, buf_idx) if self.backpressure.buffers[buf_idx].is_empty: + + # If pub is in same process as this channel, avoid TCP + if self._local_backpressure is not None: + self._local_backpressure.free(self.id, buf_idx) + return + try: ack = Command.RX_ACK.value + uint64_to_bytes(msg_id) self._pub_writer.write(ack) @@ -290,6 +301,10 @@ def unregister_client(self, client_id: UUID) -> None: self.backpressure.free(client_id) + elif client_id == self.pub_id and self._local_backpressure is not None: + self._local_backpressure.free(self.id) + self._local_backpressure = None + del self.clients[client_id] def clear_cache(self): @@ -307,9 +322,9 @@ def _ensure_address(address: AddressType | None) -> Address: return address -class _ChannelManager: +class ChannelManager: - _registry: typing.Dict[Address, typing.Dict[UUID, _Channel]] + _registry: typing.Dict[Address, typing.Dict[UUID, Channel]] def __init__(self): default_address = Address.from_string(GRAPHSERVER_ADDR) @@ -320,27 +335,46 @@ async def get( self, pub_id: UUID, graph_address: AddressType | None = None, + local_backpressure: Backpressure | None = None, create: bool = False - ) -> _Channel: + ) -> Channel: graph_address = _ensure_address(graph_address) channel = self._registry.get(graph_address, dict()).get(pub_id, None) if create and channel is None: - channel = await _Channel.create(pub_id, graph_address) + channel = await Channel.create(pub_id, graph_address, local_backpressure) channels = self._registry.get(graph_address, dict()) channels[pub_id] = channel self._registry[graph_address] = channels if channel is None: raise KeyError(f"channel {pub_id=} {graph_address=} does not exist") return channel - + async def register( + self, + pub_id: UUID, + client_id: UUID, + queue: NotificationQueue, + graph_address: AddressType | None = None, + ) -> Channel: + return await self._register(pub_id, client_id, queue, graph_address, None) + + async def register_local_pub( + self, + pub_id: UUID, + local_backpressure: Backpressure | None = None, + graph_address: AddressType | None = None, + ) -> Channel: + return await self._register(pub_id, pub_id, None, graph_address, local_backpressure) + + async def _register( self, pub_id: UUID, client_id: UUID, queue: NotificationQueue | None = None, - graph_address: AddressType | None = None - ) -> _Channel: - channel = await self.get(pub_id, graph_address, create = True) + graph_address: AddressType | None = None, + local_backpressure: Backpressure | None = None + ) -> Channel: + channel = await self.get(pub_id, graph_address, create = True, local_backpressure = local_backpressure) channel.register_client(client_id, queue) return channel @@ -362,4 +396,4 @@ async def unregister( logger.debug(f'closed channel {pub_id}: no clients') -CHANNELS = _ChannelManager() \ No newline at end of file +CHANNELS = ChannelManager() \ No newline at end of file diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index 5087fa97..4d3c2a6b 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -10,7 +10,7 @@ from .backpressure import Backpressure from .shm import SHMContext from .graphserver import GraphService -from .messagechannel import CHANNELS, _Channel +from .messagechannel import CHANNELS, Channel from .messagemarshal import MessageMarshal from .netprotocol import ( @@ -55,7 +55,7 @@ class Publisher: _connection_task: "asyncio.Task[None]" _channels: Dict[UUID, PubChannelInfo] _channel_tasks: Dict[UUID, "asyncio.Task[None]"] - _local_channel: _Channel + _local_channel: Channel _address: Address _backpressure: Backpressure _num_buffers: int @@ -80,7 +80,8 @@ async def create( port: Optional[int] = None, buf_size: int = DEFAULT_SHM_SIZE, num_buffers: int = 32, - **kwargs, + start_paused: bool = False, + force_tcp: bool = False, ) -> "Publisher": graph_service = GraphService(graph_address) reader, writer = await graph_service.open_connection() @@ -90,7 +91,15 @@ async def create( writer.write(encode_str(topic)) pub_id = UUID(await read_str(reader)) - pub = cls(pub_id, topic, shm, graph_address, num_buffers, **kwargs) + pub = cls( + id = pub_id, + topic = topic, + shm = shm, + graph_address = graph_address, + num_buffers = num_buffers, + start_paused = start_paused, + force_tcp = force_tcp + ) start_port = int( os.getenv(PUBLISHER_START_PORT_ENV, PUBLISHER_START_PORT_DEFAULT) @@ -115,8 +124,10 @@ async def create( name = f'pub-{pub.id}: _graph_connection' ) - pub._local_channel = await CHANNELS.register( - pub.id, pub.id, None, pub._graph_address + pub._local_channel = await CHANNELS.register_local_pub( + pub_id = pub.id, + local_backpressure = None if force_tcp else pub._backpressure, + graph_address = pub._graph_address, ) logger.debug(f'created pub {pub.id=} {topic=} {channel_server_address=}') @@ -249,6 +260,7 @@ async def _handle_channel( elif msg == Command.RX_ACK.value: msg_id = await read_int(reader) self._backpressure.free(info.id, msg_id % self._num_buffers) + logger.info('TCP_ACK') except (ConnectionResetError, BrokenPipeError): logger.debug(f"Publisher {self.id}: Channel {info.id} connection fail") @@ -288,8 +300,7 @@ async def broadcast(self, obj: Any) -> None: # Get local channel and put variable there for local tx if not self._force_tcp: - if self._local_channel.put_local(self._msg_id, obj): - self._backpressure.lease(self._local_channel.id, buf_idx) + self._local_channel.put_local(self._msg_id, obj) if self._force_tcp or any(ch.pid != self.pid or not ch.shm_ok for ch in self._channels.values()): with MessageMarshal.serialize(self._msg_id, obj) as (total_size, header, buffers): diff --git a/src/ezmsg/core/subclient.py b/src/ezmsg/core/subclient.py index ebed727e..d5a9f5b4 100644 --- a/src/ezmsg/core/subclient.py +++ b/src/ezmsg/core/subclient.py @@ -7,7 +7,7 @@ from copy import deepcopy from .graphserver import GraphService -from .messagechannel import CHANNELS, NotificationQueue, _Channel +from .messagechannel import CHANNELS, NotificationQueue, Channel from .netprotocol import ( AddressType, @@ -38,7 +38,7 @@ class Subscriber: # NOTE: This is an optimization to retain a local handle to channels # so that dict lookup and wrapper contextmanager aren't in hotpath - _channels: typing.Dict[UUID, _Channel] + _channels: typing.Dict[UUID, Channel] @classmethod async def create( From 07ad4492ae517886d171582e888d8ddcbaddf190 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Thu, 18 Sep 2025 14:03:55 -0400 Subject: [PATCH 48/88] remove debug statement --- src/ezmsg/core/pubclient.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index 4d3c2a6b..cca6f001 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -260,7 +260,6 @@ async def _handle_channel( elif msg == Command.RX_ACK.value: msg_id = await read_int(reader) self._backpressure.free(info.id, msg_id % self._num_buffers) - logger.info('TCP_ACK') except (ConnectionResetError, BrokenPipeError): logger.debug(f"Publisher {self.id}: Channel {info.id} connection fail") From 380e3b08ce547432f4fed9226e7ddb2d4d1136cf Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Thu, 18 Sep 2025 15:35:25 -0400 Subject: [PATCH 49/88] bugfix: local backpressure --- src/ezmsg/core/messagechannel.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/src/ezmsg/core/messagechannel.py b/src/ezmsg/core/messagechannel.py index a55e15ff..1ce56272 100644 --- a/src/ezmsg/core/messagechannel.py +++ b/src/ezmsg/core/messagechannel.py @@ -51,7 +51,7 @@ class Channel: _pub_task: asyncio.Task[None] _pub_writer: asyncio.StreamWriter _graph_address: AddressType | None - _local_backpressure: Backpressure | None = None + _local_backpressure: Backpressure | None def __init__( self, @@ -60,7 +60,6 @@ def __init__( num_buffers: int, shm: SHMContext | None, graph_address: AddressType | None, - local_backpressure: Backpressure | None = None, ) -> None: self.id = id self.pub_id = pub_id @@ -72,15 +71,14 @@ def __init__( self.backpressure = Backpressure(self.num_buffers) self.clients = dict() self._graph_address = graph_address - self._local_backpressure = local_backpressure + self._local_backpressure = None @classmethod async def create( cls, pub_id: UUID, graph_address: AddressType, - local_backpressure: Backpressure | None = None - ) -> "": + ) -> "Channel": graph_service = GraphService(graph_address) graph_reader, graph_writer = await graph_service.open_connection() @@ -118,7 +116,7 @@ async def create( num_buffers = await read_int(reader) - chan = cls(UUID(id_str), pub_id, num_buffers, shm, graph_address, local_backpressure) + chan = cls(UUID(id_str), pub_id, num_buffers, shm, graph_address) chan._graph_task = asyncio.create_task( chan._graph_connection(graph_reader, graph_writer), @@ -226,12 +224,12 @@ def put_local(self, msg_id: int, msg: typing.Any) -> None: raise ValueError('cannot put_local without access to publisher backpressure (is publisher in same process?)') buf_idx = msg_id % self.num_buffers - self.cache_id[buf_idx] = msg_id - self.cache[buf_idx] = msg self._notify_clients(msg_id) # if buffer is not available after notify, subs were notified if not self.backpressure.available(buf_idx): + self.cache_id[buf_idx] = msg_id + self.cache[buf_idx] = msg self._local_backpressure.lease(self.id, buf_idx) @contextmanager @@ -286,8 +284,15 @@ def get(self, msg_id: int, client_id: UUID) -> typing.Generator[typing.Any, None except (BrokenPipeError, ConnectionResetError): logger.info(f"ack fail: channel:{self.id} -> pub:{self.pub_id}") - def register_client(self, client_id: UUID, queue: NotificationQueue | None = None) -> None: - self.clients[client_id] = queue + def register_client( + self, + client_id: UUID, + queue: NotificationQueue | None = None, + local_backpressure: Backpressure | None = None, + ) -> None: + self.clients[client_id] = queue + if client_id == self.pub_id: + self._local_backpressure = local_backpressure def unregister_client(self, client_id: UUID) -> None: queue = self.clients[client_id] @@ -335,13 +340,12 @@ async def get( self, pub_id: UUID, graph_address: AddressType | None = None, - local_backpressure: Backpressure | None = None, create: bool = False ) -> Channel: graph_address = _ensure_address(graph_address) channel = self._registry.get(graph_address, dict()).get(pub_id, None) if create and channel is None: - channel = await Channel.create(pub_id, graph_address, local_backpressure) + channel = await Channel.create(pub_id, graph_address) channels = self._registry.get(graph_address, dict()) channels[pub_id] = channel self._registry[graph_address] = channels @@ -374,8 +378,8 @@ async def _register( graph_address: AddressType | None = None, local_backpressure: Backpressure | None = None ) -> Channel: - channel = await self.get(pub_id, graph_address, create = True, local_backpressure = local_backpressure) - channel.register_client(client_id, queue) + channel = await self.get(pub_id, graph_address, create = True) + channel.register_client(client_id, queue, local_backpressure) return channel async def unregister( From b05096ae289461adf121f08b38721886edd9e249 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Sun, 21 Sep 2025 09:52:18 -0400 Subject: [PATCH 50/88] added test iters --- src/ezmsg/util/perf/analysis.py | 23 ++++++++------- src/ezmsg/util/perf/impl.py | 5 ++++ src/ezmsg/util/perf/run.py | 52 +++++++++++++++++++-------------- 3 files changed, 47 insertions(+), 33 deletions(-) diff --git a/src/ezmsg/util/perf/analysis.py b/src/ezmsg/util/perf/analysis.py index 9d33a068..681c9dac 100644 --- a/src/ezmsg/util/perf/analysis.py +++ b/src/ezmsg/util/perf/analysis.py @@ -14,6 +14,7 @@ from .impl import ( TestParameters, Metrics, + TestLogEntry, ) import ezmsg.core as ez @@ -59,15 +60,15 @@ def load_perf(perf: Path) -> xr.Dataset: params: typing.List[TestParameters] = [] - results: typing.List[Metrics] = [] + results: typing.List[typing.List[Metrics]] = [] with open(perf, 'r') as perf_f: info: TestEnvironmentInfo = json.loads(next(perf_f), cls = MessageDecoder) for line in perf_f: - obj = json.loads(line, cls = MessageDecoder) - params.append(obj['params']) - results.append(obj['results']) + obj: TestLogEntry = json.loads(line, cls = MessageDecoder) + params.append(obj.params) + results.append(obj.results) n_clients_axis = list(sorted(set([p.n_clients for p in params]))) msg_size_axis = list(sorted(set([p.msg_size for p in params]))) @@ -91,19 +92,20 @@ def load_perf(perf: Path) -> xr.Dataset: len(config_axis) )) * np.nan for p, r in zip(params, results): + # tests are run multiple times; get the median value for each metric + values = list(sorted([getattr(v, field.name) for v in r])) + value = values[len(values)//2] m[ n_clients_axis.index(p.n_clients), msg_size_axis.index(p.msg_size), comms_axis.index(p.comms), config_axis.index(p.config) - ] = getattr(r, field.name) + ] = value data_vars[field.name] = xr.DataArray(m, dims = dims, coords = coords) dataset = xr.Dataset(data_vars, attrs = dict(info = info)) return dataset -NOISE_BAND_PCT = 5.0 # +/-5% is "in the noise" for comparisons - def _escape(s: str) -> str: return html.escape(str(s), quote=True) @@ -121,7 +123,6 @@ def _legend_block() -> str:

Legend

  • Comparison mode: values are percentages (100 = no change).
  • -
  • Noise band: ±5% considered negligible (no color).
  • Green: improvement (↑ sample/data rate, ↓ latency).
  • Red: regression (↓ sample/data rate, ↑ latency).
@@ -232,7 +233,7 @@ def _base_css() -> str: """ -def _color_for_comparison(value: float, metric: str) -> str: +def _color_for_comparison(value: float, metric: str, noise_band_pct: float = 5.0) -> str: """ Returns inline CSS background for a comparison % value. value: e.g., 97.3, 104.8, etc. @@ -257,11 +258,11 @@ def _color_for_comparison(value: float, metric: str) -> str: return "" # Noise band: keep neutral - if magnitude <= NOISE_BAND_PCT: + if magnitude <= noise_band_pct: return "" # Scale 5%..50% across 0..1; clamp - scale = max(0.0, min(1.0, (magnitude - NOISE_BAND_PCT) / 45.0)) + scale = max(0.0, min(1.0, (magnitude - noise_band_pct) / 45.0)) # Choose hue and lightness; use HSL with gentle saturation hue = "var(--green)" if sign_good else "var(--red)" diff --git a/src/ezmsg/util/perf/impl.py b/src/ezmsg/util/perf/impl.py index 58357b8c..5edc0474 100644 --- a/src/ezmsg/util/perf/impl.py +++ b/src/ezmsg/util/perf/impl.py @@ -313,3 +313,8 @@ class TestParameters: duration: float num_buffers: int + +@dataclasses.dataclass +class TestLogEntry: + params: TestParameters + results: typing.List[Metrics] \ No newline at end of file diff --git a/src/ezmsg/util/perf/run.py b/src/ezmsg/util/perf/run.py index d7280e51..599f3b02 100644 --- a/src/ezmsg/util/perf/run.py +++ b/src/ezmsg/util/perf/run.py @@ -8,7 +8,8 @@ from ..messagecodec import MessageEncoder from .envinfo import TestEnvironmentInfo from .impl import ( - TestParameters, + TestParameters, + TestLogEntry, perform_test, Communication, CONFIGS, @@ -19,12 +20,12 @@ def get_datestamp() -> str: return datetime.datetime.now().strftime("%Y%m%d_%H%M%S") -@dataclass -class PerfRunArgs: - duration: float - num_buffers: int -def perf_run(args: PerfRunArgs) -> None: +def perf_run( + duration: float, + num_buffers: int, + iters: int, +) -> None: msg_sizes = [2 ** exp for exp in range(4, 25, 8)] n_clients = [2 ** exp for exp in range(0, 6, 2)] comms = [c for c in Communication] @@ -44,20 +45,22 @@ def perf_run(args: PerfRunArgs) -> None: n_clients = n_clients, config = config.__name__, comms = comms.value, - duration = args.duration, - num_buffers = args.num_buffers + duration = duration, + num_buffers = num_buffers ) - results = perform_test( - n_clients = n_clients, - duration = args.duration, - msg_size = msg_size, - buffers = args.num_buffers, - comms = comms, - config = config, - ) - - output = dict( + results = [ + perform_test( + n_clients = n_clients, + duration = duration, + msg_size = msg_size, + buffers = num_buffers, + comms = comms, + config = config, + ) for _ in range(iters) + ] + + output = TestLogEntry( params = params, results = results ) @@ -79,10 +82,15 @@ def setup_run_cmdline(subparsers: argparse._SubParsersAction) -> None: default=32, help="shared memory buffers (default = 32)", ) + p_run.add_argument( + "--iters", "-i", + type = int, + default = 3, + help = "number of times to run each test" + ) p_run.set_defaults(_handler=lambda ns: perf_run( - PerfRunArgs( - duration = ns.duration, - num_buffers = ns.num_buffers - ) + duration = ns.duration, + num_buffers = ns.num_buffers, + iters = ns.iters )) \ No newline at end of file From 50756f662faf16fa89dd3f273877a47890203b1b Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 22 Sep 2025 13:06:30 -0400 Subject: [PATCH 51/88] more cmdline --- src/ezmsg/util/perf/impl.py | 8 ++- src/ezmsg/util/perf/run.py | 99 +++++++++++++++++++++++++++++++------ 2 files changed, 90 insertions(+), 17 deletions(-) diff --git a/src/ezmsg/util/perf/impl.py b/src/ezmsg/util/perf/impl.py index 5edc0474..6bb0ded7 100644 --- a/src/ezmsg/util/perf/impl.py +++ b/src/ezmsg/util/perf/impl.py @@ -190,7 +190,13 @@ def relay(config: ConfigSettings) -> Configuration: return relays, connections -CONFIGS: typing.Iterable[Configurator] = [fanin, fanout, relay] +CONFIGS: typing.Mapping[str, Configurator] = { + c.__name__: c for c in [ + fanin, + fanout, + relay + ] +} class Communication(enum.StrEnum): LOCAL = "local" diff --git a/src/ezmsg/util/perf/run.py b/src/ezmsg/util/perf/run.py index 599f3b02..094389c2 100644 --- a/src/ezmsg/util/perf/run.py +++ b/src/ezmsg/util/perf/run.py @@ -2,8 +2,7 @@ import datetime import itertools import argparse - -from dataclasses import dataclass +import typing from ..messagecodec import MessageEncoder from .envinfo import TestEnvironmentInfo @@ -17,6 +16,10 @@ import ezmsg.core as ez +DEFAULT_MSG_SIZES = [2 ** exp for exp in range(4, 25, 8)] +DEFAULT_N_CLIENTS = [2 ** exp for exp in range(0, 6, 2)] +DEFAULT_COMMS = [c for c in Communication] + def get_datestamp() -> str: return datetime.datetime.now().strftime("%Y%m%d_%H%M%S") @@ -25,38 +28,62 @@ def perf_run( duration: float, num_buffers: int, iters: int, + msg_sizes: typing.Iterable[int] | None, + n_clients: typing.Iterable[int] | None, + comms: typing.Iterable[str] | None, + configs: typing.Iterable[str] | None, ) -> None: - msg_sizes = [2 ** exp for exp in range(4, 25, 8)] - n_clients = [2 ** exp for exp in range(0, 6, 2)] - comms = [c for c in Communication] - - test_list = list(itertools.product(msg_sizes, n_clients, CONFIGS, comms)) + + if n_clients is None: + n_clients = DEFAULT_N_CLIENTS + if any(c <= 0 for c in n_clients): + ez.logger.error('All tests must have >0 clients') + return + + if msg_sizes is None: + msg_sizes = DEFAULT_MSG_SIZES + if any(s <= 0 for s in msg_sizes): + ez.logger.error('All msg_sizes must be >0 bytes') + + try: + communications = DEFAULT_COMMS if comms is None else [Communication(c) for c in comms] + except ValueError: + ez.logger.error(f"Invalid test communications requested. Valid communications: {', '.join([c.value for c in Communication])}") + return + + try: + configurators = list(CONFIGS.values()) if configs is None else [CONFIGS[c] for c in configs] + except ValueError: + ez.logger.error(f"Invalid test configuration requested. Valid configurations: {', '.join([c for c in CONFIGS])}") + return + + test_list = list(itertools.product(msg_sizes, n_clients, configurators, communications)) with open(f'perf_{get_datestamp()}.txt', 'w') as out_f: out_f.write(json.dumps(TestEnvironmentInfo(), cls = MessageEncoder) + "\n") - for test_idx, (msg_size, n_clients, config, comms) in enumerate(test_list): + for test_idx, (msg_size, clients, conf, comm) in enumerate(test_list): ez.logger.info(f"RUNNING TEST {test_idx + 1} / {len(test_list)} ({(test_idx / len(test_list)) * 100.0:0.2f} %)") params = TestParameters( msg_size = msg_size, - n_clients = n_clients, - config = config.__name__, - comms = comms.value, + n_clients = clients, + config = conf.__name__, + comms = comm.value, duration = duration, num_buffers = num_buffers ) results = [ perform_test( - n_clients = n_clients, + n_clients = clients, duration = duration, msg_size = msg_size, buffers = num_buffers, - comms = comms, - config = config, + comms = comm, + config = conf, ) for _ in range(iters) ] @@ -70,27 +97,67 @@ def perf_run( def setup_run_cmdline(subparsers: argparse._SubParsersAction) -> None: p_run = subparsers.add_parser("run", help="run performance test") + p_run.add_argument( "--duration", type=float, default=2.0, help="individual test duration in seconds (default = 2.0)", ) + p_run.add_argument( "--num-buffers", type=int, default=32, help="shared memory buffers (default = 32)", ) + p_run.add_argument( "--iters", "-i", type = int, default = 3, - help = "number of times to run each test" + help = "number of times to run each test (default = 3)" ) + p_run.add_argument( + "--msg-sizes", + type = int, + default = None, + nargs = "*", + help = f"message sizes in bytes (default = {DEFAULT_MSG_SIZES})" + ) + + p_run.add_argument( + "--n-clients", + type = int, + default = None, + nargs = "*", + help = f"number of clients (default = {DEFAULT_N_CLIENTS})" + ) + + p_run.add_argument( + "--comms", + type = str, + default = None, + nargs = "*", + help = f"communication strategies to test (default = {[c.value for c in DEFAULT_COMMS]})" + ) + + p_run.add_argument( + "--configs", + type = str, + default = None, + nargs = "*", + help = f"configurations to test (default = {[c for c in CONFIGS]})" + ) + + p_run.set_defaults(_handler=lambda ns: perf_run( duration = ns.duration, num_buffers = ns.num_buffers, - iters = ns.iters + iters = ns.iters, + msg_sizes = ns.msg_sizes, + n_clients = ns.n_clients, + comms = ns.comms, + configs = ns.configs, )) \ No newline at end of file From 8823c61001c2c5fb2386f370d56dd6aca288a219 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 22 Sep 2025 15:00:11 -0400 Subject: [PATCH 52/88] changing sample_rate calculation --- src/ezmsg/util/perf/impl.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ezmsg/util/perf/impl.py b/src/ezmsg/util/perf/impl.py index 6bb0ded7..8421a51b 100644 --- a/src/ezmsg/util/perf/impl.py +++ b/src/ezmsg/util/perf/impl.py @@ -264,10 +264,10 @@ def perform_test( process_components = process_components, ) - return calculate_metrics(sink) + return calculate_metrics(sink, duration) -def calculate_metrics(sink: LoadTestSink) -> Metrics: +def calculate_metrics(sink: LoadTestSink, duration: float) -> Metrics: # Log some useful summary statistics min_timestamp = min(timestamp for timestamp, _, _ in sink.STATE.received_data) @@ -286,7 +286,7 @@ def calculate_metrics(sink: LoadTestSink) -> Metrics: num_samples = len(sink.STATE.received_data) ez.logger.info(f"Samples received: {num_samples}") - sample_rate = num_samples / (max_timestamp - min_timestamp) + sample_rate = num_samples / duration ez.logger.info(f"Sample rate: {sample_rate} Hz") latency_mean = total_latency / num_samples ez.logger.info(f"Mean latency: {latency_mean} s") From de379160cf8d3e495f897db3ec5fe32d00be423f Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 22 Sep 2025 16:26:24 -0400 Subject: [PATCH 53/88] bugfix: force_tcp on local channel causes backpressure --- src/ezmsg/core/messagechannel.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/src/ezmsg/core/messagechannel.py b/src/ezmsg/core/messagechannel.py index 1ce56272..5d22936b 100644 --- a/src/ezmsg/core/messagechannel.py +++ b/src/ezmsg/core/messagechannel.py @@ -200,20 +200,25 @@ async def _publisher_connection(self, reader: asyncio.StreamReader) -> None: self.cache[buf_idx] = obj self.cache_id[buf_idx] = msg_id - self._notify_clients(msg_id) + if not self._notify_clients(msg_id): + # Nobody is listening; need to ack! + self._acknowledge(msg_id) except (ConnectionResetError, BrokenPipeError): logger.debug(f"connection fail: channel:{self.id} - pub:{self.pub_id}") finally: await close_stream_writer(self._pub_writer) - logger.debug(f"disconnected: channel:{self.id} -> pub:{id}") + logger.debug(f"disconnected: channel:{self.id} -> pub:{self.pub_id}") - def _notify_clients(self, msg_id: int) -> None: + def _notify_clients(self, msg_id: int) -> bool: + """ notify interested clients and return true if any were notified """ + buf_idx = msg_id % self.num_buffers for client_id, queue in self.clients.items(): if queue is None: continue # queue is none if this is the pub - self.backpressure.lease(client_id, msg_id % self.num_buffers) + self.backpressure.lease(client_id, buf_idx) queue.put_nowait((self.pub_id, msg_id)) + return not self.backpressure.available(buf_idx) def put_local(self, msg_id: int, msg: typing.Any) -> None: """ @@ -224,10 +229,7 @@ def put_local(self, msg_id: int, msg: typing.Any) -> None: raise ValueError('cannot put_local without access to publisher backpressure (is publisher in same process?)') buf_idx = msg_id % self.num_buffers - self._notify_clients(msg_id) - - # if buffer is not available after notify, subs were notified - if not self.backpressure.available(buf_idx): + if self._notify_clients(msg_id): self.cache_id[buf_idx] = msg_id self.cache[buf_idx] = msg self._local_backpressure.lease(self.id, buf_idx) @@ -278,11 +280,14 @@ def get(self, msg_id: int, client_id: UUID) -> typing.Generator[typing.Any, None self._local_backpressure.free(self.id, buf_idx) return - try: - ack = Command.RX_ACK.value + uint64_to_bytes(msg_id) - self._pub_writer.write(ack) - except (BrokenPipeError, ConnectionResetError): - logger.info(f"ack fail: channel:{self.id} -> pub:{self.pub_id}") + self._acknowledge(msg_id) + + def _acknowledge(self, msg_id: int) -> None: + try: + ack = Command.RX_ACK.value + uint64_to_bytes(msg_id) + self._pub_writer.write(ack) + except (BrokenPipeError, ConnectionResetError): + logger.info(f"ack fail: channel:{self.id} -> pub:{self.pub_id}") def register_client( self, From d714d818dda5a9fd8dffe5710d82354f0da0b28e Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 22 Sep 2025 16:52:50 -0400 Subject: [PATCH 54/88] fix: n-clients = 0 is useful and works --- src/ezmsg/util/perf/run.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ezmsg/util/perf/run.py b/src/ezmsg/util/perf/run.py index 094389c2..3ff70e91 100644 --- a/src/ezmsg/util/perf/run.py +++ b/src/ezmsg/util/perf/run.py @@ -36,8 +36,8 @@ def perf_run( if n_clients is None: n_clients = DEFAULT_N_CLIENTS - if any(c <= 0 for c in n_clients): - ez.logger.error('All tests must have >0 clients') + if any(c < 0 for c in n_clients): + ez.logger.error('All tests must have >=0 clients') return if msg_sizes is None: From 4bfd958030d947973af641d19fb2c0a4810a7f4f Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 22 Sep 2025 19:10:46 -0400 Subject: [PATCH 55/88] better support for msg_size = 0 and n_clients = 0 --- src/ezmsg/util/perf/impl.py | 10 ++++++---- src/ezmsg/util/perf/run.py | 4 ++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/ezmsg/util/perf/impl.py b/src/ezmsg/util/perf/impl.py index 8421a51b..faec5d19 100644 --- a/src/ezmsg/util/perf/impl.py +++ b/src/ezmsg/util/perf/impl.py @@ -183,10 +183,12 @@ def relay(config: ConfigSettings) -> Configuration: connections: ez.NetworkDefinition = [] relays = [LoadTestRelay(config.settings) for _ in range(config.n_clients)] - connections.append((config.source.OUTPUT, relays[0].INPUT)) - for from_relay, to_relay in zip(relays[:-1], relays[1:]): - connections.append((from_relay.OUTPUT, to_relay.INPUT)) - connections.append((relays[-1].OUTPUT, config.sink.INPUT)) + if len(relays): + connections.append((config.source.OUTPUT, relays[0].INPUT)) + for from_relay, to_relay in zip(relays[:-1], relays[1:]): + connections.append((from_relay.OUTPUT, to_relay.INPUT)) + connections.append((relays[-1].OUTPUT, config.sink.INPUT)) + else: connections.append((config.source.OUTPUT, config.sink.INPUT)) return relays, connections diff --git a/src/ezmsg/util/perf/run.py b/src/ezmsg/util/perf/run.py index 3ff70e91..80bf80cb 100644 --- a/src/ezmsg/util/perf/run.py +++ b/src/ezmsg/util/perf/run.py @@ -42,8 +42,8 @@ def perf_run( if msg_sizes is None: msg_sizes = DEFAULT_MSG_SIZES - if any(s <= 0 for s in msg_sizes): - ez.logger.error('All msg_sizes must be >0 bytes') + if any(s < 0 for s in msg_sizes): + ez.logger.error('All msg_sizes must be >=0 bytes') try: communications = DEFAULT_COMMS if comms is None else [Communication(c) for c in comms] From 4dac9f9c932e22e7af6a5ad633b58612a3b0915f Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 22 Sep 2025 19:49:13 -0400 Subject: [PATCH 56/88] also force_tcp on tcp_spread oops --- src/ezmsg/util/perf/impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ezmsg/util/perf/impl.py b/src/ezmsg/util/perf/impl.py index faec5d19..114a8b7a 100644 --- a/src/ezmsg/util/perf/impl.py +++ b/src/ezmsg/util/perf/impl.py @@ -220,7 +220,7 @@ def perform_test( dynamic_size = int(msg_size), duration = duration, buffers = buffers, - force_tcp = (comms == Communication.TCP), + force_tcp = (comms in (Communication.TCP, Communication.TCP_SPREAD)), ) source = LoadTestSource(settings) From 1e54cfc4bf2ad97120a838b85a93915636244507 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 22 Sep 2025 20:38:43 -0400 Subject: [PATCH 57/88] tcp benchmarks un-necessarily hamstrung by extra TCP transmission to empty local channel --- src/ezmsg/core/pubclient.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index cca6f001..273d5631 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -298,8 +298,7 @@ async def broadcast(self, obj: Any) -> None: await self._backpressure.wait(buf_idx) # Get local channel and put variable there for local tx - if not self._force_tcp: - self._local_channel.put_local(self._msg_id, obj) + self._local_channel.put_local(self._msg_id, obj) if self._force_tcp or any(ch.pid != self.pid or not ch.shm_ok for ch in self._channels.values()): with MessageMarshal.serialize(self._msg_id, obj) as (total_size, header, buffers): @@ -322,7 +321,7 @@ async def broadcast(self, obj: Any) -> None: for channel in self._channels.values(): - if (not self._force_tcp) and self.pid == channel.pid and channel.shm_ok: + if self.pid == channel.pid and channel.shm_ok: continue # Local transmission handled by channel.put elif (not self._force_tcp) and self.pid != channel.pid and channel.shm_ok: From e780e8a21d239307a5c1c72f91d15485480cd7c5 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 22 Sep 2025 20:41:09 -0400 Subject: [PATCH 58/88] viztracer is useful for profiling --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 961271b4..c0888f86 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ dev = [ {include-group = "lint"}, {include-group = "test"}, "pre-commit>=4.3.0", + "viztracer>=1.0.4", ] lint = [ "flake8>=7.3.0", From eab4cea8191f9863725066b40be042153d8aa837 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 22 Sep 2025 21:33:34 -0400 Subject: [PATCH 59/88] force_tcp bugfix --- src/ezmsg/core/pubclient.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index 273d5631..78366547 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -126,7 +126,7 @@ async def create( pub._local_channel = await CHANNELS.register_local_pub( pub_id = pub.id, - local_backpressure = None if force_tcp else pub._backpressure, + local_backpressure = pub._backpressure, graph_address = pub._graph_address, ) From 1d59244ecbd78e0dac01785e41784115cd90d97e Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Tue, 23 Sep 2025 09:49:28 -0400 Subject: [PATCH 60/88] bugfix for rare error on shutdown --- src/ezmsg/core/messagechannel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ezmsg/core/messagechannel.py b/src/ezmsg/core/messagechannel.py index 5d22936b..1f121363 100644 --- a/src/ezmsg/core/messagechannel.py +++ b/src/ezmsg/core/messagechannel.py @@ -204,7 +204,7 @@ async def _publisher_connection(self, reader: asyncio.StreamReader) -> None: # Nobody is listening; need to ack! self._acknowledge(msg_id) - except (ConnectionResetError, BrokenPipeError): + except (ConnectionResetError, BrokenPipeError, asyncio.IncompleteReadError): logger.debug(f"connection fail: channel:{self.id} - pub:{self.pub_id}") finally: From fdb89dc3ed00daa317da4884554dae97cb3ced47 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Tue, 23 Sep 2025 12:47:33 -0400 Subject: [PATCH 61/88] hotpath optim --- src/ezmsg/core/messagechannel.py | 65 +++++++++++++++++--------------- src/ezmsg/core/messagemarshal.py | 4 ++ 2 files changed, 38 insertions(+), 31 deletions(-) diff --git a/src/ezmsg/core/messagechannel.py b/src/ezmsg/core/messagechannel.py index 1f121363..66b8c5a0 100644 --- a/src/ezmsg/core/messagechannel.py +++ b/src/ezmsg/core/messagechannel.py @@ -7,7 +7,7 @@ from contextlib import contextmanager, suppress from .shm import SHMContext -from .messagemarshal import MessageMarshal +from .messagemarshal import MessageMarshal, NO_MESSAGE from .backpressure import Backpressure from .graphserver import GraphService @@ -41,6 +41,7 @@ class Channel: topic: str num_buffers: int + tcp_cache: typing.List[memoryview] cache: typing.List[typing.Any] cache_id: typing.List[int | None] shm: SHMContext | None @@ -66,6 +67,7 @@ def __init__( self.num_buffers = num_buffers self.shm = shm + self.tcp_cache = [memoryview(NO_MESSAGE)] * self.num_buffers self.cache_id = [None] * self.num_buffers self.cache = [None] * self.num_buffers self.backpressure = Backpressure(self.num_buffers) @@ -178,6 +180,15 @@ async def _publisher_connection(self, reader: asyncio.StreamReader) -> None: msg_id = await read_int(reader) buf_idx = msg_id % self.num_buffers + # # Profiling indicates its faster to repeatedly + # # reconstruct <=512kB msgs from memory for up to + # # about 4 subs -- deepcopy is about 4x slower + # # which I suspect will be majority of cases + # # With more than 4 subs, one deepcopy may be faster + # # for <=512kB msgs, but becomes remarkably time + # # consuming for >= 512kB msgs. + # cache_msg = len(self.clients) > 4 + if msg == Command.TX_SHM.value: shm_name = await read_str(reader) @@ -191,14 +202,17 @@ async def _publisher_connection(self, reader: asyncio.StreamReader) -> None: "Invalid SHM received from publisher; may be dead" ) raise + + # if cache_msg: + # ... elif msg == Command.TX_TCP.value: buf_size = await read_int(reader) obj_bytes = await reader.readexactly(buf_size) + self.tcp_cache[buf_idx] = memoryview(obj_bytes).toreadonly() - with MessageMarshal.obj_from_mem(memoryview(obj_bytes)) as obj: - self.cache[buf_idx] = obj - self.cache_id[buf_idx] = msg_id + # if cache_msg: + # ... if not self._notify_clients(msg_id): # Nobody is listening; need to ack! @@ -238,39 +252,28 @@ def put_local(self, msg_id: int, msg: typing.Any) -> None: def get(self, msg_id: int, client_id: UUID) -> typing.Generator[typing.Any, None, None]: """get object from cache; if not in cache and shm provided -- get from shm""" + dispatched = False + buf_idx = msg_id % self.num_buffers if self.cache_id[buf_idx] == msg_id: + dispatched = True yield self.cache[buf_idx] - else: - if self.shm is None: - raise CacheMiss - + elif self.shm is not None: with self.shm.buffer(buf_idx, readonly=True) as mem: - if MessageMarshal.msg_id(mem) != msg_id: - raise CacheMiss - - with MessageMarshal.obj_from_mem(mem) as obj: - # Could deepcopy and put in cache here... - # Profiling indicates its faster to repeatedly - # reconstruct <=512kB msgs from memory for up to - # about 4 subs -- deepcopy is about 4x slower - # which I suspect will be majority of cases - # With more than 4 subs, one deepcopy may be faster - # for <=512kB msgs, but becomes remarkably time - # consuming for >= 512kB msgs. - - # TODO: Implement a tuning method that runs on - # first ezmsg run that determines this tradeoff - # and decides fastest behavior; then decide - # to cache or not cache depending on this profiling - # for this particular machine. - # NOTE: We have information about the message size and - # the current fanout at time of RX, so we could - # intelligently copy here! - # self.shm.buf_size # Buffer size - # len(self.clients) # Channel fanout + if MessageMarshal.msg_id(mem) == msg_id: + with MessageMarshal.obj_from_mem(mem) as obj: + dispatched = True + yield obj + + if not dispatched: + tcp_mem = self.tcp_cache[buf_idx] + if MessageMarshal.msg_id(tcp_mem) == msg_id: + with MessageMarshal.obj_from_mem(tcp_mem) as obj: + dispatched = True yield obj + else: + raise CacheMiss self.backpressure.free(client_id, buf_idx) if self.backpressure.buffers[buf_idx].is_empty: diff --git a/src/ezmsg/core/messagemarshal.py b/src/ezmsg/core/messagemarshal.py index 2266044c..70f5a340 100644 --- a/src/ezmsg/core/messagemarshal.py +++ b/src/ezmsg/core/messagemarshal.py @@ -7,6 +7,7 @@ _PREAMBLE = b"EZ" _PREAMBLE_LEN = len(_PREAMBLE) +NO_MESSAGE = _PREAMBLE + (b'\xFF' * 8) + (b'\x00' * 8) class UndersizedMemory(Exception): @@ -71,6 +72,9 @@ def obj_from_mem(cls, mem: memoryview) -> Generator[Any, None, None]: sidx = _PREAMBLE_LEN + UINT64_SIZE num_buffers = bytes_to_uint(mem[sidx : sidx + UINT64_SIZE]) + if num_buffers == 0: + raise ValueError("invalid message in memory") + sidx += UINT64_SIZE buf_sizes = [0] * num_buffers for i in range(num_buffers): From 856802fc434f1e048a152fa0d0ce4289f8f73474 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Tue, 23 Sep 2025 13:09:54 -0400 Subject: [PATCH 62/88] added median latency --- src/ezmsg/util/perf/impl.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/ezmsg/util/perf/impl.py b/src/ezmsg/util/perf/impl.py index 114a8b7a..99d11046 100644 --- a/src/ezmsg/util/perf/impl.py +++ b/src/ezmsg/util/perf/impl.py @@ -41,6 +41,7 @@ class Metrics: num_msgs: int sample_rate: float latency_mean: float + latency_median: float latency_total: float data_rate: float @@ -274,12 +275,11 @@ def calculate_metrics(sink: LoadTestSink, duration: float) -> Metrics: # Log some useful summary statistics min_timestamp = min(timestamp for timestamp, _, _ in sink.STATE.received_data) max_timestamp = max(timestamp for timestamp, _, _ in sink.STATE.received_data) - total_latency = abs( - sum( - receive_timestamp - send_timestamp - for send_timestamp, receive_timestamp, _ in sink.STATE.received_data - ) - ) + latency = [ + receive_timestamp - send_timestamp + for send_timestamp, receive_timestamp, _ in sink.STATE.received_data + ] + total_latency = abs(sum(latency)) counters = list(sorted(t[2] for t in sink.STATE.received_data)) dropped_samples = sum( @@ -291,7 +291,9 @@ def calculate_metrics(sink: LoadTestSink, duration: float) -> Metrics: sample_rate = num_samples / duration ez.logger.info(f"Sample rate: {sample_rate} Hz") latency_mean = total_latency / num_samples + latency_median = list(sorted(latency))[len(latency) // 2] ez.logger.info(f"Mean latency: {latency_mean} s") + ez.logger.info(f"Median latency: {latency_median} s") ez.logger.info(f"Total latency: {total_latency} s") total_data = num_samples * sink.SETTINGS.dynamic_size @@ -307,6 +309,7 @@ def calculate_metrics(sink: LoadTestSink, duration: float) -> Metrics: num_msgs = num_samples, sample_rate = sample_rate, latency_mean = latency_mean, + latency_median = latency_median, latency_total = total_latency, data_rate = data_rate ) From a77eaa836dba023107a203dcd1c3d294e980e3ca Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Tue, 23 Sep 2025 15:30:46 -0400 Subject: [PATCH 63/88] added median latency to performance report --- src/ezmsg/util/perf/analysis.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/ezmsg/util/perf/analysis.py b/src/ezmsg/util/perf/analysis.py index 681c9dac..ed729187 100644 --- a/src/ezmsg/util/perf/analysis.py +++ b/src/ezmsg/util/perf/analysis.py @@ -246,11 +246,11 @@ def _color_for_comparison(value: float, metric: str, noise_band_pct: float = 5.0 delta = value - 100.0 # Determine direction: + is good for sample/data; - is good for latency - if metric in ("sample_rate", "data_rate"): + if 'rate' in metric: # positive delta good, negative bad magnitude = abs(delta) sign_good = delta > 0 - elif metric == "latency_mean": + elif 'latency' in metric: # negative delta good (lower latency) magnitude = abs(delta) sign_good = delta < 0 @@ -316,7 +316,7 @@ def summary(perf_path: Path, baseline_path: Path | None, html: bool = False) -> if html: # Ensure expected columns exist - expected_cols = {"sample_rate", "data_rate", "latency_mean"} + expected_cols = {"sample_rate", "data_rate", "latency_mean", "latency_median"} missing = expected_cols - set(df.columns) if missing: raise ValueError(f"Missing expected columns in dataset: {missing}") @@ -358,7 +358,7 @@ def summary(perf_path: Path, baseline_path: Path | None, html: bool = False) -> # Render each group for (config, comms), g in groups: # Keep only expected columns in order - cols = ["n_clients", "msg_size", "sample_rate", "data_rate", "latency_mean"] + cols = ["n_clients", "msg_size", "sample_rate", "data_rate", "latency_mean", "latency_median"] g = g[cols].copy() # String format some columns (msg_size with separators) @@ -374,17 +374,20 @@ def summary(perf_path: Path, baseline_path: Path | None, html: bool = False) -> sample_rate {'' if relative else '(msgs/s)'} data_rate {'' if relative else '(MB/s)'} latency_mean {'' if relative else '(us)'} + latency_median {'' if relative else '(us)'} """ body_rows: list[str] = [] for _, row in g.iterrows(): - sr, dr, lt = row["sample_rate"], row["data_rate"], row["latency_mean"] + sr, dr, lt, lm = row["sample_rate"], row["data_rate"], row["latency_mean"], row["latency_median"] dr = dr if relative else dr / 2**20 lt = lt if relative else lt * 1e6 + lm = lm if relative else lm * 1e6 sr_style = _color_for_comparison(sr, "sample_rate") if relative else "" dr_style = _color_for_comparison(dr, "data_rate") if relative else "" lt_style = _color_for_comparison(lt, "latency_mean") if relative else "" + lm_style = _color_for_comparison(lm, "latency_median") if relative else "" body_rows.append( "" @@ -393,6 +396,7 @@ def summary(perf_path: Path, baseline_path: Path | None, html: bool = False) -> f"{_format_number(sr)}" f"{_format_number(dr)}" f"{_format_number(lt)}" + f"{_format_number(lm)}" "" ) table_html = f"{header}{''.join(body_rows)}
" From a0490d976cdec4c42a738886a8e09f5c4acfc05b Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Tue, 23 Sep 2025 16:07:10 -0400 Subject: [PATCH 64/88] try using time.time --- src/ezmsg/util/perf/impl.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ezmsg/util/perf/impl.py b/src/ezmsg/util/perf/impl.py index 99d11046..9fdfe740 100644 --- a/src/ezmsg/util/perf/impl.py +++ b/src/ezmsg/util/perf/impl.py @@ -75,16 +75,16 @@ async def initialize(self) -> None: @ez.publisher(OUTPUT) async def publish(self) -> typing.AsyncGenerator: ez.logger.info(f"Load test publisher started. (PID: {os.getpid()})") - start_time = time.perf_counter() + start_time = time.time() while self.running: - current_time = time.perf_counter() + current_time = time.time() if current_time - start_time >= self.SETTINGS.duration: break yield ( self.OUTPUT, LoadTestSample( - _timestamp=time.perf_counter(), + _timestamp=time.time(), counter=self.counter, dynamic_data=np.zeros( int(self.SETTINGS.dynamic_size // 8), dtype=np.float32 @@ -133,7 +133,7 @@ async def receive(self, sample: LoadTestSample) -> None: f"{sample.counter - counter - 1} samples skipped!" ) self.STATE.received_data.append( - (sample._timestamp, time.perf_counter(), sample.counter) + (sample._timestamp, time.time(), sample.counter) ) self.STATE.counters[sample.key] = sample.counter From dbd48b1b81a2f655700a05bcd60ac41dfcb6e00b Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Tue, 28 Oct 2025 11:56:31 -0400 Subject: [PATCH 65/88] more diagnostic test results in less time --- src/ezmsg/util/perf/command.py | 3 +++ src/ezmsg/util/perf/run.py | 27 ++++++++++++++++++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/src/ezmsg/util/perf/command.py b/src/ezmsg/util/perf/command.py index a100863d..ab919999 100644 --- a/src/ezmsg/util/perf/command.py +++ b/src/ezmsg/util/perf/command.py @@ -12,3 +12,6 @@ def command() -> None: ns = parser.parse_args() ns._handler(ns) + +if __name__ == "__main__": + command() diff --git a/src/ezmsg/util/perf/run.py b/src/ezmsg/util/perf/run.py index 80bf80cb..855a94f2 100644 --- a/src/ezmsg/util/perf/run.py +++ b/src/ezmsg/util/perf/run.py @@ -3,6 +3,7 @@ import itertools import argparse import typing +import random from ..messagecodec import MessageEncoder from .envinfo import TestEnvironmentInfo @@ -32,6 +33,7 @@ def perf_run( n_clients: typing.Iterable[int] | None, comms: typing.Iterable[str] | None, configs: typing.Iterable[str] | None, + grid: bool, ) -> None: if n_clients is None: @@ -45,6 +47,13 @@ def perf_run( if any(s < 0 for s in msg_sizes): ez.logger.error('All msg_sizes must be >=0 bytes') + if not grid and len(list(n_clients)) != len(list(msg_sizes)): + ez.logger.warning( + "Not performing a grid test of all combinations of n_clients and msg_sizes, but " + \ + "len(n_clients) != len(msg_sizes). " + \ + "If you want to perform all combinations of n_clients and msg_sizes, use --grid" + ) + try: communications = DEFAULT_COMMS if comms is None else [Communication(c) for c in comms] except ValueError: @@ -56,8 +65,16 @@ def perf_run( except ValueError: ez.logger.error(f"Invalid test configuration requested. Valid configurations: {', '.join([c for c in CONFIGS])}") return + + subitr = itertools.product if grid else zip + + test_list = [ + (clients, msg_size, config, comms) + for clients, msg_size in subitr(n_clients, msg_sizes) + for config, comms in itertools.product(configurators, communications) + ] - test_list = list(itertools.product(msg_sizes, n_clients, configurators, communications)) + random.shuffle(test_list) with open(f'perf_{get_datestamp()}.txt', 'w') as out_f: @@ -151,6 +168,13 @@ def setup_run_cmdline(subparsers: argparse._SubParsersAction) -> None: help = f"configurations to test (default = {[c for c in CONFIGS]})" ) + p_run.add_argument( + "--grid", + action = "store_true", + help = "perform all combinations of msg_sizes and n_clients " + \ + "(default: False; msg_sizes and n_clients must match in length)" + ) + p_run.set_defaults(_handler=lambda ns: perf_run( duration = ns.duration, @@ -160,4 +184,5 @@ def setup_run_cmdline(subparsers: argparse._SubParsersAction) -> None: n_clients = ns.n_clients, comms = ns.comms, configs = ns.configs, + grid = ns.grid )) \ No newline at end of file From 7fcfaa864427427eaaeb3db025bd7902d5b56c3a Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Tue, 28 Oct 2025 16:25:10 -0400 Subject: [PATCH 66/88] early quit and test randomization --- src/ezmsg/util/perf/analysis.py | 21 +++--- src/ezmsg/util/perf/impl.py | 4 +- src/ezmsg/util/perf/run.py | 113 ++++++++++++++++++++++++-------- 3 files changed, 101 insertions(+), 37 deletions(-) diff --git a/src/ezmsg/util/perf/analysis.py b/src/ezmsg/util/perf/analysis.py index ed729187..c8699ab9 100644 --- a/src/ezmsg/util/perf/analysis.py +++ b/src/ezmsg/util/perf/analysis.py @@ -59,21 +59,21 @@ def load_perf(perf: Path) -> xr.Dataset: - params: typing.List[TestParameters] = [] - results: typing.List[typing.List[Metrics]] = [] + all_results: typing.Dict[TestParameters, typing.List[Metrics]] = dict() with open(perf, 'r') as perf_f: info: TestEnvironmentInfo = json.loads(next(perf_f), cls = MessageDecoder) for line in perf_f: obj: TestLogEntry = json.loads(line, cls = MessageDecoder) - params.append(obj.params) - results.append(obj.results) + metrics = all_results.get(obj.params, list()) + metrics.append(obj.results) + all_results[obj.params] = metrics - n_clients_axis = list(sorted(set([p.n_clients for p in params]))) - msg_size_axis = list(sorted(set([p.msg_size for p in params]))) - comms_axis = list(sorted(set([p.comms for p in params]))) - config_axis = list(sorted(set([p.config for p in params]))) + n_clients_axis = list(sorted(set([p.n_clients for p in all_results.keys()]))) + msg_size_axis = list(sorted(set([p.msg_size for p in all_results.keys()]))) + comms_axis = list(sorted(set([p.comms for p in all_results.keys()]))) + config_axis = list(sorted(set([p.config for p in all_results.keys()]))) dims = ['n_clients', 'msg_size', 'comms', 'config'] coords = { @@ -91,7 +91,7 @@ def load_perf(perf: Path) -> xr.Dataset: len(comms_axis), len(config_axis) )) * np.nan - for p, r in zip(params, results): + for p, r in all_results.items(): # tests are run multiple times; get the median value for each metric values = list(sorted([getattr(v, field.name) for v in r])) value = values[len(values)//2] @@ -305,7 +305,10 @@ def summary(perf_path: Path, baseline_path: Path | None, html: bool = False) -> # These raw stats are still valuable to have, but are confusing # when making relative comparisons perf = perf.drop_vars(['latency_total', 'num_msgs']) + perf = perf.stack(params = ['n_clients', 'msg_size']).dropna('params') df = perf.squeeze().to_dataframe() + df = df.drop('n_clients', axis = 1) + df = df.drop('msg_size', axis = 1) for _, config_ds in perf.groupby('config'): for _, comms_ds in config_ds.groupby('comms'): diff --git a/src/ezmsg/util/perf/impl.py b/src/ezmsg/util/perf/impl.py index 9fdfe740..289dde8c 100644 --- a/src/ezmsg/util/perf/impl.py +++ b/src/ezmsg/util/perf/impl.py @@ -315,7 +315,7 @@ def calculate_metrics(sink: LoadTestSink, duration: float) -> Metrics: ) -@dataclasses.dataclass +@dataclasses.dataclass(unsafe_hash=True) class TestParameters: msg_size: int n_clients: int @@ -328,4 +328,4 @@ class TestParameters: @dataclasses.dataclass class TestLogEntry: params: TestParameters - results: typing.List[Metrics] \ No newline at end of file + results: Metrics \ No newline at end of file diff --git a/src/ezmsg/util/perf/run.py b/src/ezmsg/util/perf/run.py index 855a94f2..8ae5eda7 100644 --- a/src/ezmsg/util/perf/run.py +++ b/src/ezmsg/util/perf/run.py @@ -1,9 +1,14 @@ +import os +import sys import json import datetime import itertools import argparse import typing import random +import time + +from contextlib import contextmanager, redirect_stdout, redirect_stderr from ..messagecodec import MessageEncoder from .envinfo import TestEnvironmentInfo @@ -21,10 +26,66 @@ DEFAULT_N_CLIENTS = [2 ** exp for exp in range(0, 6, 2)] DEFAULT_COMMS = [c for c in Communication] +# --- Output Suppression Context Manager (from the previous solution) --- +@contextmanager +def suppress_output(verbose: bool = False): + """Context manager to redirect stdout and stderr to os.devnull""" + if verbose: + yield + else: + # Open the null device for writing + with open(os.devnull, 'w') as fnull: + # Redirect both stdout and stderr to the null device + with redirect_stderr(fnull): + with redirect_stdout(fnull): + yield + +CHECK_FOR_QUIT = lambda: False + +if sys.platform.startswith('win'): + import msvcrt + def _check_for_quit_win() -> bool: + """ + Checks for the 'q' key press in a non-blocking way. + Returns True if 'q' is pressed (case-insensitive), False otherwise. + """ + # Windows: Use msvcrt for non-blocking keyboard hit detection + if msvcrt.kbhit(): # type: ignore + # Read the key press (returns bytes) + key = msvcrt.getch() # type: ignore + try: + # Decode and check for 'q' + return key.decode().lower() == 'q' + except UnicodeDecodeError: + # Handle potential non-text key presses gracefully + return False + return False + + CHECK_FOR_QUIT = _check_for_quit_win + +else: + import select + def _check_for_quit() -> bool: + """ + Checks for the 'q' key press in a non-blocking way. + Returns True if 'q' is pressed (case-insensitive), False otherwise. + """ + # Linux/macOS: Use select to check if stdin has data + # select.select(rlist, wlist, xlist, timeout) + # timeout=0 makes it non-blocking + if sys.stdin.isatty(): + i, o, e = select.select([sys.stdin], [], [], 0) # type: ignore + if i: + # Read the buffered character + key = sys.stdin.read(1) + return key.lower() == 'q' + return False + + CHECK_FOR_QUIT = _check_for_quit + def get_datestamp() -> str: return datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - def perf_run( duration: float, num_buffers: int, @@ -69,10 +130,10 @@ def perf_run( subitr = itertools.product if grid else zip test_list = [ - (clients, msg_size, config, comms) - for clients, msg_size in subitr(n_clients, msg_sizes) - for config, comms in itertools.product(configurators, communications) - ] + (msg_size, clients, conf, comm) + for msg_size, clients in subitr(msg_sizes, n_clients) + for conf, comm in itertools.product(configurators, communications) + ] * iters random.shuffle(test_list) @@ -80,33 +141,34 @@ def perf_run( out_f.write(json.dumps(TestEnvironmentInfo(), cls = MessageEncoder) + "\n") + ez.logger.info("Starting perf tests. Press 'q' + enter to quit tests early.") + time.sleep(3.0) # Give user an opportunity to read message. + for test_idx, (msg_size, clients, conf, comm) in enumerate(test_list): - ez.logger.info(f"RUNNING TEST {test_idx + 1} / {len(test_list)} ({(test_idx / len(test_list)) * 100.0:0.2f} %)") - - params = TestParameters( - msg_size = msg_size, - n_clients = clients, - config = conf.__name__, - comms = comm.value, - duration = duration, - num_buffers = num_buffers - ) - - results = [ - perform_test( + if CHECK_FOR_QUIT(): + ez.logger.info("Stopping tests early...") + break + + ez.logger.info(f"TEST {test_idx + 1}/{len(test_list)}: {clients=}, {msg_size=}, conf={conf.__name__}, comm={comm.value}") + + output = TestLogEntry( + params = TestParameters( + msg_size = msg_size, + n_clients = clients, + config = conf.__name__, + comms = comm.value, + duration = duration, + num_buffers = num_buffers + ), + results = perform_test( n_clients = clients, duration = duration, msg_size = msg_size, buffers = num_buffers, comms = comm, config = conf, - ) for _ in range(iters) - ] - - output = TestLogEntry( - params = params, - results = results + ) ) out_f.write(json.dumps(output, cls = MessageEncoder) + "\n") @@ -175,7 +237,6 @@ def setup_run_cmdline(subparsers: argparse._SubParsersAction) -> None: "(default: False; msg_sizes and n_clients must match in length)" ) - p_run.set_defaults(_handler=lambda ns: perf_run( duration = ns.duration, num_buffers = ns.num_buffers, @@ -184,5 +245,5 @@ def setup_run_cmdline(subparsers: argparse._SubParsersAction) -> None: n_clients = ns.n_clients, comms = ns.comms, configs = ns.configs, - grid = ns.grid + grid = ns.grid, )) \ No newline at end of file From 2489c783e95686c75aba3d5833d8d4455274902d Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Wed, 5 Nov 2025 11:45:00 -0500 Subject: [PATCH 67/88] attempting to stabilize perf test results --- src/ezmsg/util/perf/analysis.py | 28 ++-- src/ezmsg/util/perf/impl.py | 40 +++-- src/ezmsg/util/perf/run.py | 111 +++++++----- src/ezmsg/util/perf/util.py | 288 ++++++++++++++++++++++++++++++++ 4 files changed, 399 insertions(+), 68 deletions(-) create mode 100644 src/ezmsg/util/perf/util.py diff --git a/src/ezmsg/util/perf/analysis.py b/src/ezmsg/util/perf/analysis.py index c8699ab9..7f731e6d 100644 --- a/src/ezmsg/util/perf/analysis.py +++ b/src/ezmsg/util/perf/analysis.py @@ -93,14 +93,12 @@ def load_perf(perf: Path) -> xr.Dataset: )) * np.nan for p, r in all_results.items(): # tests are run multiple times; get the median value for each metric - values = list(sorted([getattr(v, field.name) for v in r])) - value = values[len(values)//2] m[ n_clients_axis.index(p.n_clients), msg_size_axis.index(p.msg_size), comms_axis.index(p.comms), config_axis.index(p.config) - ] = value + ] = np.median([getattr(v, field.name) for v in r]) data_vars[field.name] = xr.DataArray(m, dims = dims, coords = coords) dataset = xr.Dataset(data_vars, attrs = dict(info = info)) @@ -233,13 +231,13 @@ def _base_css() -> str: """ -def _color_for_comparison(value: float, metric: str, noise_band_pct: float = 5.0) -> str: +def _color_for_comparison(value: float, metric: str, noise_band_pct: float = 10.0) -> str: """ Returns inline CSS background for a comparison % value. value: e.g., 97.3, 104.8, etc. For sample_rate/data_rate: improvement > 100 (good). For latency_mean: improvement < 100 (good). - Noise band ±5% around 100 is neutral. + Noise band ±10% around 100 is neutral. """ if not (isinstance(value, (int, float)) and math.isfinite(value)): return "" @@ -302,9 +300,10 @@ def summary(perf_path: Path, baseline_path: Path | None, html: bool = False) -> env_diff = format_env_diff(info.diff(baseline_info)) output += env_diff + '\n\n' - # These raw stats are still valuable to have, but are confusing - # when making relative comparisons - perf = perf.drop_vars(['latency_total', 'num_msgs']) + # These raw stats are still valuable to have, but are confusing + # when making relative comparisons + perf = perf.drop_vars(['latency_total', 'num_msgs']) + perf = perf.stack(params = ['n_clients', 'msg_size']).dropna('params') df = perf.squeeze().to_dataframe() df = df.drop('n_clients', axis = 1) @@ -319,7 +318,7 @@ def summary(perf_path: Path, baseline_path: Path | None, html: bool = False) -> if html: # Ensure expected columns exist - expected_cols = {"sample_rate", "data_rate", "latency_mean", "latency_median"} + expected_cols = {"sample_rate_mean", "sample_rate_median", "data_rate", "latency_mean", "latency_median"} missing = expected_cols - set(df.columns) if missing: raise ValueError(f"Missing expected columns in dataset: {missing}") @@ -361,7 +360,7 @@ def summary(perf_path: Path, baseline_path: Path | None, html: bool = False) -> # Render each group for (config, comms), g in groups: # Keep only expected columns in order - cols = ["n_clients", "msg_size", "sample_rate", "data_rate", "latency_mean", "latency_median"] + cols = ["n_clients", "msg_size", "sample_rate_mean", "sample_rate_median", "data_rate", "latency_mean", "latency_median"] g = g[cols].copy() # String format some columns (msg_size with separators) @@ -374,7 +373,8 @@ def summary(perf_path: Path, baseline_path: Path | None, html: bool = False) -> n_clients msg_size {'' if relative else '(b)'} - sample_rate {'' if relative else '(msgs/s)'} + sample_rate_mean {'' if relative else '(msgs/s)'} + sample_rate_median {'' if relative else '(msgs/s)'} data_rate {'' if relative else '(MB/s)'} latency_mean {'' if relative else '(us)'} latency_median {'' if relative else '(us)'} @@ -383,11 +383,12 @@ def summary(perf_path: Path, baseline_path: Path | None, html: bool = False) -> """ body_rows: list[str] = [] for _, row in g.iterrows(): - sr, dr, lt, lm = row["sample_rate"], row["data_rate"], row["latency_mean"], row["latency_median"] + sr, srm, dr, lt, lm = row["sample_rate_mean"], row["sample_rate_median"], row["data_rate"], row["latency_mean"], row["latency_median"] dr = dr if relative else dr / 2**20 lt = lt if relative else lt * 1e6 lm = lm if relative else lm * 1e6 - sr_style = _color_for_comparison(sr, "sample_rate") if relative else "" + sr_style = _color_for_comparison(sr, "sample_rate_mean") if relative else "" + srm_style = _color_for_comparison(srm, "sample_rate_median") if relative else "" dr_style = _color_for_comparison(dr, "data_rate") if relative else "" lt_style = _color_for_comparison(lt, "latency_mean") if relative else "" lm_style = _color_for_comparison(lm, "latency_median") if relative else "" @@ -397,6 +398,7 @@ def summary(perf_path: Path, baseline_path: Path | None, html: bool = False) -> f"{_format_number(row['n_clients'])}" f"{_escape(row['msg_size'])}" f"{_format_number(sr)}" + f"{_format_number(srm)}" f"{_format_number(dr)}" f"{_format_number(lt)}" f"{_format_number(lm)}" diff --git a/src/ezmsg/util/perf/impl.py b/src/ezmsg/util/perf/impl.py index 289dde8c..4788cc2c 100644 --- a/src/ezmsg/util/perf/impl.py +++ b/src/ezmsg/util/perf/impl.py @@ -7,12 +7,16 @@ import ezmsg.core as ez +from ezmsg.core.netprotocol import Address + try: import numpy as np except ImportError: ez.logger.error("ezmsg perf requires numpy") raise +from .util import stable_perf + def collect( components: typing.Optional[typing.Mapping[str, ez.Component]] = None, @@ -39,7 +43,8 @@ def collect( @dataclasses.dataclass class Metrics: num_msgs: int - sample_rate: float + sample_rate_mean: float + sample_rate_median: float latency_mean: float latency_median: float latency_total: float @@ -214,7 +219,8 @@ def perform_test( msg_size: int, buffers: int, comms: Communication, - config: Configurator + config: Configurator, + graph_address: Address ) -> Metrics: settings = LoadTestSettings( @@ -261,11 +267,13 @@ def perform_test( components["PROC"] = proc_collection process_components = [proc_collection] - ez.run( - components = components, - connections = connections, - process_components = process_components, - ) + with stable_perf(): + ez.run( + components = components, + connections = connections, + process_components = process_components, + graph_address = graph_address + ) return calculate_metrics(sink, duration) @@ -286,18 +294,21 @@ def calculate_metrics(sink: LoadTestSink, duration: float) -> Metrics: [max((x1 - x0) - 1, 0) for x1, x0 in zip(counters[1:], counters[:-1])] ) + rx_timestamps = np.array([rx_ts for _, rx_ts, _ in sink.STATE.received_data]) num_samples = len(sink.STATE.received_data) - ez.logger.info(f"Samples received: {num_samples}") - sample_rate = num_samples / duration - ez.logger.info(f"Sample rate: {sample_rate} Hz") + samplerate_mean = num_samples / duration + samplerate_median = 1.0 / float(np.median(np.diff(rx_timestamps))) latency_mean = total_latency / num_samples latency_median = list(sorted(latency))[len(latency) // 2] + total_data = num_samples * sink.SETTINGS.dynamic_size + data_rate = total_data / (max_timestamp - min_timestamp) + + ez.logger.info(f"Samples received: {num_samples}") + ez.logger.info(f"Mean sample rate: {samplerate_mean} Hz") + ez.logger.info(f"Median sample rate: {samplerate_median} Hz") ez.logger.info(f"Mean latency: {latency_mean} s") ez.logger.info(f"Median latency: {latency_median} s") ez.logger.info(f"Total latency: {total_latency} s") - - total_data = num_samples * sink.SETTINGS.dynamic_size - data_rate = total_data / (max_timestamp - min_timestamp) ez.logger.info(f"Data rate: {data_rate * 1e-6} MB/s") if dropped_samples: @@ -307,7 +318,8 @@ def calculate_metrics(sink: LoadTestSink, duration: float) -> Metrics: return Metrics( num_msgs = num_samples, - sample_rate = sample_rate, + sample_rate_mean = samplerate_mean, + sample_rate_median = samplerate_median, latency_mean = latency_mean, latency_median = latency_median, latency_total = total_latency, diff --git a/src/ezmsg/util/perf/run.py b/src/ezmsg/util/perf/run.py index 8ae5eda7..c23913f2 100644 --- a/src/ezmsg/util/perf/run.py +++ b/src/ezmsg/util/perf/run.py @@ -6,12 +6,16 @@ import argparse import typing import random -import time +from datetime import datetime, timedelta from contextlib import contextmanager, redirect_stdout, redirect_stderr +import ezmsg.core as ez +from ezmsg.core.graphserver import GraphServer + from ..messagecodec import MessageEncoder from .envinfo import TestEnvironmentInfo +from .util import warmup from .impl import ( TestParameters, TestLogEntry, @@ -20,8 +24,6 @@ CONFIGS, ) -import ezmsg.core as ez - DEFAULT_MSG_SIZES = [2 ** exp for exp in range(4, 25, 8)] DEFAULT_N_CLIENTS = [2 ** exp for exp in range(0, 6, 2)] DEFAULT_COMMS = [c for c in Communication] @@ -84,7 +86,7 @@ def _check_for_quit() -> bool: CHECK_FOR_QUIT = _check_for_quit def get_datestamp() -> str: - return datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + return datetime.now().strftime("%Y%m%d_%H%M%S") def perf_run( duration: float, @@ -95,6 +97,7 @@ def perf_run( comms: typing.Iterable[str] | None, configs: typing.Iterable[str] | None, grid: bool, + warmup_dur: float, ) -> None: if n_clients is None: @@ -137,41 +140,59 @@ def perf_run( random.shuffle(test_list) - with open(f'perf_{get_datestamp()}.txt', 'w') as out_f: - - out_f.write(json.dumps(TestEnvironmentInfo(), cls = MessageEncoder) + "\n") - - ez.logger.info("Starting perf tests. Press 'q' + enter to quit tests early.") - time.sleep(3.0) # Give user an opportunity to read message. - - for test_idx, (msg_size, clients, conf, comm) in enumerate(test_list): - - if CHECK_FOR_QUIT(): - ez.logger.info("Stopping tests early...") - break - - ez.logger.info(f"TEST {test_idx + 1}/{len(test_list)}: {clients=}, {msg_size=}, conf={conf.__name__}, comm={comm.value}") - - output = TestLogEntry( - params = TestParameters( - msg_size = msg_size, - n_clients = clients, - config = conf.__name__, - comms = comm.value, - duration = duration, - num_buffers = num_buffers - ), - results = perform_test( - n_clients = clients, - duration = duration, - msg_size = msg_size, - buffers = num_buffers, - comms = comm, - config = conf, + server = GraphServer() + server.start() + + d = datetime(1,1,1) + timedelta(seconds = len(test_list) * duration) + total_dur_str = ':'.join([str(n) for n in [d.day - 1, d.hour, d.minute, d.second] if n != 0]) + ez.logger.info(f"About to run {len(test_list)} tests of {duration} sec each.") + ez.logger.info(f"Expected total duration ~{total_dur_str})") + ez.logger.info(f"Please try to avoid running other taxing software while this perf test runs.") + ez.logger.info(f"NOTE: Tests swallow interrupt. After warmup, use 'q' then [enter] to quit tests early.") + + try: + ez.logger.info(f"Warming up for {warmup_dur} seconds...") + warmup(warmup_dur) + + with open(f'perf_{get_datestamp()}.txt', 'w') as out_f: + out_f.write(json.dumps(TestEnvironmentInfo(), cls = MessageEncoder) + "\n") + + for test_idx, (msg_size, clients, conf, comm) in enumerate(test_list): + + if CHECK_FOR_QUIT(): + ez.logger.info("Stopping tests early...") + break + + ez.logger.info( + f"TEST {test_idx + 1}/{len(test_list)}: " \ + f"{clients=}, {msg_size=}, conf={conf.__name__}, " \ + f"comm={comm.value}" + ) + + output = TestLogEntry( + params = TestParameters( + msg_size = msg_size, + n_clients = clients, + config = conf.__name__, + comms = comm.value, + duration = duration, + num_buffers = num_buffers + ), + results = perform_test( + n_clients = clients, + duration = duration, + msg_size = msg_size, + buffers = num_buffers, + comms = comm, + config = conf, + graph_address = server.address + ) ) - ) - out_f.write(json.dumps(output, cls = MessageEncoder) + "\n") + out_f.write(json.dumps(output, cls = MessageEncoder) + "\n") + finally: + server.stop() + def setup_run_cmdline(subparsers: argparse._SubParsersAction) -> None: @@ -180,8 +201,8 @@ def setup_run_cmdline(subparsers: argparse._SubParsersAction) -> None: p_run.add_argument( "--duration", type=float, - default=2.0, - help="individual test duration in seconds (default = 2.0)", + default=0.5, + help="individual test duration in seconds (default = 0.5)", ) p_run.add_argument( @@ -194,8 +215,8 @@ def setup_run_cmdline(subparsers: argparse._SubParsersAction) -> None: p_run.add_argument( "--iters", "-i", type = int, - default = 3, - help = "number of times to run each test (default = 3)" + default = 50, + help = "number of times to run each test (default = 50)" ) p_run.add_argument( @@ -237,6 +258,13 @@ def setup_run_cmdline(subparsers: argparse._SubParsersAction) -> None: "(default: False; msg_sizes and n_clients must match in length)" ) + p_run.add_argument( + "--warmup", + type = float, + default = 60.0, + help = "warmup CPU with busy task for some number of seconds (default = 60.0)" + ) + p_run.set_defaults(_handler=lambda ns: perf_run( duration = ns.duration, num_buffers = ns.num_buffers, @@ -246,4 +274,5 @@ def setup_run_cmdline(subparsers: argparse._SubParsersAction) -> None: comms = ns.comms, configs = ns.configs, grid = ns.grid, + warmup_dur = ns.warmup, )) \ No newline at end of file diff --git a/src/ezmsg/util/perf/util.py b/src/ezmsg/util/perf/util.py new file mode 100644 index 00000000..d22d1496 --- /dev/null +++ b/src/ezmsg/util/perf/util.py @@ -0,0 +1,288 @@ +import os +import sys +import gc +import time +import statistics as stats +import contextlib +import subprocess +import platform +from dataclasses import dataclass +from typing import Iterable, List, Optional + +try: + import psutil # optional but helpful +except Exception: + psutil = None + +_IS_WIN = os.name == "nt" +_IS_MAC = sys.platform == "darwin" +_IS_LINUX = sys.platform.startswith("linux") + +# ---------- Utilities ---------- + +def _set_env_threads(single_thread: bool = True): + """ + Normalize math/threading libs so they don't spawn surprise worker threads. + """ + if single_thread: + os.environ.setdefault("OMP_NUM_THREADS", "1") + os.environ.setdefault("MKL_NUM_THREADS", "1") + os.environ.setdefault("VECLIB_MAXIMUM_THREADS", "1") + os.environ.setdefault("OPENBLAS_NUM_THREADS", "1") + os.environ.setdefault("NUMEXPR_NUM_THREADS", "1") + # Keep PYTHONHASHSEED stable for deterministic dict/set iteration costs + os.environ.setdefault("PYTHONHASHSEED", "0") + +# ---------- Priority & Affinity ---------- + +@contextlib.contextmanager +def _process_priority(): + """ + Elevate process priority in a cross-platform best-effort way. + """ + if psutil is None: + yield + return + + p = psutil.Process() + orig_nice = None + if _IS_WIN: + try: + import ctypes, ctypes.wintypes as wt + kernel32 = ctypes.WinDLL("kernel32", use_last_error=True) + ABOVE_NORMAL_PRIORITY_CLASS = 0x00008000 + HIGH_PRIORITY_CLASS = 0x00000080 + # Try High, fall back to Above Normal + if not kernel32.SetPriorityClass(kernel32.GetCurrentProcess(), HIGH_PRIORITY_CLASS): + kernel32.SetPriorityClass(kernel32.GetCurrentProcess(), ABOVE_NORMAL_PRIORITY_CLASS) + except Exception: + pass + else: + try: + orig_nice = p.nice() + # Negative nice may need privileges; try smaller magnitude first + for nice_val in (-10, -5, 0): + try: + p.nice(nice_val) + break + except Exception: + continue + except Exception: + pass + try: + yield + finally: + # restore nice if we changed it + if psutil is not None and not _IS_WIN and orig_nice is not None: + try: + p.nice(orig_nice) + except Exception: + pass + +@contextlib.contextmanager +def _cpu_affinity(prefer_isolation: bool = True): + """ + Set CPU affinity to a small, stable set of CPUs (where supported). + macOS does not support affinity via psutil; we no-op there. + """ + if psutil is None or _IS_MAC: + yield + return + + p = psutil.Process() + original = None + try: + if hasattr(p, "cpu_affinity"): + original = p.cpu_affinity() + cpus = original + if prefer_isolation and len(cpus) > 2: + # Pick two middle CPUs to avoid 0 which often handles interrupts + mid = len(cpus) // 2 + cpus = [cpus[mid-1], cpus[mid]] + p.cpu_affinity(cpus) + yield + finally: + try: + if original is not None and hasattr(p, "cpu_affinity"): + p.cpu_affinity(original) + except Exception: + pass + +# ---------- Platform-specific helpers ---------- + +@contextlib.contextmanager +def _mac_caffeinate(): + """ + Keep macOS awake during the run via a background caffeinate process. + """ + if not _IS_MAC: + yield + return + proc = None + try: + proc = subprocess.Popen(["caffeinate", "-dimsu"]) + except Exception: + proc = None + try: + yield + finally: + if proc is not None: + try: + proc.terminate() + except Exception: + pass + +@contextlib.contextmanager +def _win_timer_resolution(ms: int = 1): + """ + On Windows, request a finer system timer to stabilize sleeps and scheduling slices. + """ + if not _IS_WIN: + yield + return + import ctypes + winmm = ctypes.WinDLL("winmm") + timeBeginPeriod = winmm.timeBeginPeriod + timeEndPeriod = winmm.timeEndPeriod + try: + timeBeginPeriod(ms) + except Exception: + pass + try: + yield + finally: + try: + timeEndPeriod(ms) + except Exception: + pass + +# ---------- Warm-up & GC ---------- + +def warmup(seconds: float = 60.0, fn=None, *args, **kwargs): + """ + Optional warm-up to reach steady clocks/caches. + If fn is provided, call it in a loop for the given time. + """ + if seconds <= 0: + return + end = time.perf_counter() + target = end + seconds + if fn is None: + # Busy wait / sleep mix to heat up without heavy CPU + while time.perf_counter() < target: + x = 0 + for _ in range(10000): + x += 1 + time.sleep(0) + else: + while time.perf_counter() < target: + fn(*args, **kwargs) + +@contextlib.contextmanager +def gc_pause(): + """ + Disable GC inside timing windows; re-enable and collect after. + """ + was_enabled = gc.isenabled() + try: + gc.disable() + yield + finally: + if was_enabled: + gc.enable() + gc.collect() + +# ---------- Robust statistics ---------- + +def median_of_means(samples: Iterable[float], k: int = 5) -> float: + """ + Robust estimate: split samples into k buckets (round-robin), average each, take median of bucket means. + """ + samples = list(samples) + if not samples: + return float("nan") + k = max(1, min(k, len(samples))) + buckets = [[] for _ in range(k)] + for i, v in enumerate(samples): + buckets[i % k].append(v) + means = [sum(b)/len(b) for b in buckets if b] + means.sort() + return means[len(means)//2] + +def coef_var(samples: Iterable[float]) -> float: + vals = list(samples) + if len(vals) < 2: + return 0.0 + m = sum(vals)/len(vals) + if m == 0: + return 0.0 + sd = stats.pstdev(vals) + return sd / m + +# ---------- Public context manager ---------- + +@dataclass +class PerfOptions: + single_thread_math: bool = True + prefer_isolated_cpus: bool = True + warmup_seconds: float = 0.0 + adjust_priority: bool = True + tweak_timer_windows: bool = True + keep_mac_awake: bool = True + +@contextlib.contextmanager +def stable_perf(opts: PerfOptions = PerfOptions()): + """ + Wrap your perf runs with this context manager for a stabler environment. + """ + _set_env_threads(opts.single_thread_math) + + cm_stack = contextlib.ExitStack() + try: + if opts.adjust_priority: + cm_stack.enter_context(_process_priority()) + if opts.tweak_timer_windows: + cm_stack.enter_context(_win_timer_resolution(1)) + if opts.prefer_isolated_cpus: + cm_stack.enter_context(_cpu_affinity(True)) + if opts.keep_mac_awake: + cm_stack.enter_context(_mac_caffeinate()) + + if opts.warmup_seconds > 0: + warmup(opts.warmup_seconds) + + with gc_pause(): + yield + finally: + cm_stack.close() + +# ---------- Example runners ---------- + +def run_interleaved(configs: List[dict], run_fn, trials: int = 5, trial_seconds: float = 30.0, seed: int = 42): + """ + Interleave scenarios to cancel slow drift. `run_fn(config, seconds, seed_offset)` should return a dict of metrics. + Returns a list of per-trial results per config. + """ + import random + random.seed(seed) + results = [ [] for _ in configs ] + order = list(range(len(configs))) + # fixed order per pass; you'll re-use the same order for stability + for t in range(trials): + for idx in order: + res = run_fn(configs[idx], trial_seconds, seed + t) + results[idx].append(res) + return results + +def summarize_metric(trials: List[dict], key: str, mom_buckets: int = 5): + """ + Extract a metric across trial dicts and summarize with median-of-means and CV. + """ + vals = [float(tr[key]) for tr in trials if key in tr] + return { + "count": len(vals), + "mom": median_of_means(vals, mom_buckets), + "mean": sum(vals)/len(vals) if vals else float("nan"), + "p50": stats.median(vals) if vals else float("nan"), + "cv": coef_var(vals) if vals else float("nan"), + } \ No newline at end of file From f3e2560c625ca63b49efcda945fcf099109695a8 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Thu, 6 Nov 2025 09:51:06 -0500 Subject: [PATCH 68/88] tests stabilized with median of means and num-buffers = 1 --- src/ezmsg/util/perf/analysis.py | 24 ++++--- src/ezmsg/util/perf/impl.py | 60 +++++++++++----- src/ezmsg/util/perf/run.py | 120 ++++++++++++++++++++------------ 3 files changed, 133 insertions(+), 71 deletions(-) diff --git a/src/ezmsg/util/perf/analysis.py b/src/ezmsg/util/perf/analysis.py index 7f731e6d..2df4bebb 100644 --- a/src/ezmsg/util/perf/analysis.py +++ b/src/ezmsg/util/perf/analysis.py @@ -59,16 +59,22 @@ def load_perf(perf: Path) -> xr.Dataset: - all_results: typing.Dict[TestParameters, typing.List[Metrics]] = dict() + all_results: typing.Dict[TestParameters, typing.Dict[int, typing.List[Metrics]]] = dict() + + run_idx = 0 with open(perf, 'r') as perf_f: info: TestEnvironmentInfo = json.loads(next(perf_f), cls = MessageDecoder) - for line in perf_f: - obj: TestLogEntry = json.loads(line, cls = MessageDecoder) - metrics = all_results.get(obj.params, list()) - metrics.append(obj.results) - all_results[obj.params] = metrics + obj = json.loads(line, cls = MessageDecoder) + if isinstance(obj, TestEnvironmentInfo): + run_idx += 1 + elif isinstance(obj, TestLogEntry): + runs = all_results.get(obj.params, dict()) + metrics = runs.get(run_idx, list()) + metrics.append(obj.results) + runs[run_idx] = metrics + all_results[obj.params] = runs n_clients_axis = list(sorted(set([p.n_clients for p in all_results.keys()]))) msg_size_axis = list(sorted(set([p.msg_size for p in all_results.keys()]))) @@ -91,14 +97,14 @@ def load_perf(perf: Path) -> xr.Dataset: len(comms_axis), len(config_axis) )) * np.nan - for p, r in all_results.items(): - # tests are run multiple times; get the median value for each metric + for p, a in all_results.items(): + # tests are run multiple times; get the median of means m[ n_clients_axis.index(p.n_clients), msg_size_axis.index(p.msg_size), comms_axis.index(p.comms), config_axis.index(p.config) - ] = np.median([getattr(v, field.name) for v in r]) + ] = np.median([np.mean([getattr(v, field.name) for v in r]) for r in a.values()]) data_vars[field.name] = xr.DataArray(m, dims = dims, coords = coords) dataset = xr.Dataset(data_vars, attrs = dict(info = info)) diff --git a/src/ezmsg/util/perf/impl.py b/src/ezmsg/util/perf/impl.py index 4788cc2c..a5eb18c2 100644 --- a/src/ezmsg/util/perf/impl.py +++ b/src/ezmsg/util/perf/impl.py @@ -7,6 +7,7 @@ import ezmsg.core as ez +from ezmsg.util.messages.util import replace from ezmsg.core.netprotocol import Address try: @@ -52,7 +53,8 @@ class Metrics: class LoadTestSettings(ez.Settings): - duration: float + max_duration: float + num_msgs: int dynamic_size: int buffers: int force_tcp: bool @@ -66,7 +68,7 @@ class LoadTestSample: key: str -class LoadTestSender(ez.Unit): +class LoadTestSource(ez.Unit): OUTPUT = ez.OutputStream(LoadTestSample) SETTINGS = LoadTestSettings @@ -81,9 +83,12 @@ async def initialize(self) -> None: async def publish(self) -> typing.AsyncGenerator: ez.logger.info(f"Load test publisher started. (PID: {os.getpid()})") start_time = time.time() - while self.running: + for _ in range(self.SETTINGS.num_msgs): + if not self.running: + break + current_time = time.time() - if current_time - start_time >= self.SETTINGS.duration: + if current_time - start_time >= self.SETTINGS.max_duration: break yield ( @@ -92,21 +97,22 @@ async def publish(self) -> typing.AsyncGenerator: _timestamp=time.time(), counter=self.counter, dynamic_data=np.zeros( - int(self.SETTINGS.dynamic_size // 8), dtype=np.float32 + int(self.SETTINGS.dynamic_size // 4), dtype=np.float32 ), key = self.name, ), ) self.counter += 1 + ez.logger.info("Exiting publish") raise ez.Complete - -class LoadTestSource(LoadTestSender): + async def shutdown(self) -> None: self.running = False ez.logger.info(f"Samples sent: {self.counter}") + class LoadTestRelay(ez.Unit): INPUT = ez.InputStream(LoadTestSample) OUTPUT = ez.OutputStream(LoadTestSample) @@ -130,6 +136,9 @@ class LoadTestReceiver(ez.Unit): SETTINGS = LoadTestSettings STATE = LoadTestReceiverState + async def initialize(self) -> None: + ez.logger.info(f"Load test subscriber started. (PID: {os.getpid()})") + @ez.subscriber(INPUT, zero_copy=True) async def receive(self, sample: LoadTestSample) -> None: counter = self.STATE.counters.get(sample.key, -1) @@ -145,12 +154,19 @@ async def receive(self, sample: LoadTestSample) -> None: class LoadTestSink(LoadTestReceiver): + INPUT = ez.InputStream(LoadTestSample) + + @ez.subscriber(INPUT, zero_copy=True) + async def receive(self, sample: LoadTestSample) -> None: + await super().receive(sample) + if len(self.STATE.received_data) == self.SETTINGS.num_msgs: + raise ez.NormalTermination + @ez.task async def terminate(self) -> None: - ez.logger.info(f"Load test subscriber started. (PID: {os.getpid()})") - - # Wait for the duration of the load test - await asyncio.sleep(self.SETTINGS.duration) + # Wait for the max duration of the load test + await asyncio.sleep(self.SETTINGS.max_duration) + ez.logger.warning("TIMEOUT -- terminating test.") raise ez.NormalTermination @@ -178,7 +194,9 @@ def fanout(config: ConfigSettings) -> Configuration: def fanin(config: ConfigSettings) -> Configuration: """ many pubs to one sub """ connections: ez.NetworkDefinition = [(config.source.OUTPUT, config.sink.INPUT)] - pubs = [LoadTestSender(config.settings) for _ in range(config.n_clients)] + pubs = [LoadTestSource(config.settings) for _ in range(config.n_clients)] + expected_num_msgs = config.sink.SETTINGS.num_msgs * len(pubs) + config.sink.SETTINGS = replace(config.sink.SETTINGS, num_msgs = expected_num_msgs) # type: ignore for pub in pubs: connections.append((pub.OUTPUT, config.sink.INPUT)) return pubs, connections @@ -215,7 +233,8 @@ class Communication(enum.StrEnum): def perform_test( n_clients: int, - duration: float, + max_duration: float, + num_msgs: int, msg_size: int, buffers: int, comms: Communication, @@ -225,7 +244,8 @@ def perform_test( settings = LoadTestSettings( dynamic_size = int(msg_size), - duration = duration, + num_msgs = num_msgs, + max_duration = max_duration, buffers = buffers, force_tcp = (comms in (Communication.TCP, Communication.TCP_SPREAD)), ) @@ -275,10 +295,10 @@ def perform_test( graph_address = graph_address ) - return calculate_metrics(sink, duration) + return calculate_metrics(sink) -def calculate_metrics(sink: LoadTestSink, duration: float) -> Metrics: +def calculate_metrics(sink: LoadTestSink) -> Metrics: # Log some useful summary statistics min_timestamp = min(timestamp for timestamp, _, _ in sink.STATE.received_data) @@ -295,13 +315,14 @@ def calculate_metrics(sink: LoadTestSink, duration: float) -> Metrics: ) rx_timestamps = np.array([rx_ts for _, rx_ts, _ in sink.STATE.received_data]) + runtime = max_timestamp - min_timestamp num_samples = len(sink.STATE.received_data) - samplerate_mean = num_samples / duration + samplerate_mean = num_samples / runtime samplerate_median = 1.0 / float(np.median(np.diff(rx_timestamps))) latency_mean = total_latency / num_samples latency_median = list(sorted(latency))[len(latency) // 2] total_data = num_samples * sink.SETTINGS.dynamic_size - data_rate = total_data / (max_timestamp - min_timestamp) + data_rate = total_data / runtime ez.logger.info(f"Samples received: {num_samples}") ez.logger.info(f"Mean sample rate: {samplerate_mean} Hz") @@ -330,10 +351,11 @@ def calculate_metrics(sink: LoadTestSink, duration: float) -> Metrics: @dataclasses.dataclass(unsafe_hash=True) class TestParameters: msg_size: int + num_msgs: int n_clients: int config: str comms: str - duration: float + max_duration: float num_buffers: int diff --git a/src/ezmsg/util/perf/run.py b/src/ezmsg/util/perf/run.py index c23913f2..3c391bdc 100644 --- a/src/ezmsg/util/perf/run.py +++ b/src/ezmsg/util/perf/run.py @@ -6,6 +6,7 @@ import argparse import typing import random +import time from datetime import datetime, timedelta from contextlib import contextmanager, redirect_stdout, redirect_stderr @@ -89,9 +90,11 @@ def get_datestamp() -> str: return datetime.now().strftime("%Y%m%d_%H%M%S") def perf_run( - duration: float, + max_duration: float, + num_msgs: int, num_buffers: int, iters: int, + repeats: int, msg_sizes: typing.Iterable[int] | None, n_clients: typing.Iterable[int] | None, comms: typing.Iterable[str] | None, @@ -143,55 +146,70 @@ def perf_run( server = GraphServer() server.start() - d = datetime(1,1,1) + timedelta(seconds = len(test_list) * duration) - total_dur_str = ':'.join([str(n) for n in [d.day - 1, d.hour, d.minute, d.second] if n != 0]) - ez.logger.info(f"About to run {len(test_list)} tests of {duration} sec each.") - ez.logger.info(f"Expected total duration ~{total_dur_str})") + ez.logger.info(f"About to run {len(test_list)} tests (repeated {repeats} times) of {max_duration} sec (max) each.") + ez.logger.info(f"During each test, source will attempt to send {num_msgs} messages to the sink.") ez.logger.info(f"Please try to avoid running other taxing software while this perf test runs.") ez.logger.info(f"NOTE: Tests swallow interrupt. After warmup, use 'q' then [enter] to quit tests early.") + quitting = False + + start_time = time.time() + try: ez.logger.info(f"Warming up for {warmup_dur} seconds...") warmup(warmup_dur) with open(f'perf_{get_datestamp()}.txt', 'w') as out_f: - out_f.write(json.dumps(TestEnvironmentInfo(), cls = MessageEncoder) + "\n") + for _ in range(repeats): + out_f.write(json.dumps(TestEnvironmentInfo(), cls = MessageEncoder) + "\n") - for test_idx, (msg_size, clients, conf, comm) in enumerate(test_list): + for test_idx, (msg_size, clients, conf, comm) in enumerate(test_list): - if CHECK_FOR_QUIT(): - ez.logger.info("Stopping tests early...") - break + if CHECK_FOR_QUIT(): + ez.logger.info("Stopping tests early...") + quitting = True + break + + ez.logger.info( + f"TEST {test_idx + 1}/{len(test_list)}: " \ + f"{clients=}, {msg_size=}, conf={conf.__name__}, " \ + f"comm={comm.value}" + ) - ez.logger.info( - f"TEST {test_idx + 1}/{len(test_list)}: " \ - f"{clients=}, {msg_size=}, conf={conf.__name__}, " \ - f"comm={comm.value}" - ) - - output = TestLogEntry( - params = TestParameters( - msg_size = msg_size, - n_clients = clients, - config = conf.__name__, - comms = comm.value, - duration = duration, - num_buffers = num_buffers - ), - results = perform_test( - n_clients = clients, - duration = duration, - msg_size = msg_size, - buffers = num_buffers, - comms = comm, - config = conf, - graph_address = server.address + output = TestLogEntry( + params = TestParameters( + msg_size = msg_size, + num_msgs = num_msgs, + n_clients = clients, + config = conf.__name__, + comms = comm.value, + max_duration = max_duration, + num_buffers = num_buffers + ), + results = perform_test( + n_clients = clients, + max_duration = max_duration, + num_msgs = num_msgs, + msg_size = msg_size, + buffers = num_buffers, + comms = comm, + config = conf, + graph_address = server.address + ) ) - ) - out_f.write(json.dumps(output, cls = MessageEncoder) + "\n") + out_f.write(json.dumps(output, cls = MessageEncoder) + "\n") + + if quitting: + break + finally: server.stop() + d = datetime(1,1,1) + timedelta(seconds = time.time() - start_time) + dur_str = ':'.join([str(n) for n in [d.day - 1, d.hour, d.minute, d.second] if n != 0]) + ez.logger.info(f"Tests concluded. Wallclock Runtime: {dur_str}s") + + def setup_run_cmdline(subparsers: argparse._SubParsersAction) -> None: @@ -199,24 +217,38 @@ def setup_run_cmdline(subparsers: argparse._SubParsersAction) -> None: p_run = subparsers.add_parser("run", help="run performance test") p_run.add_argument( - "--duration", + "--max-duration", type=float, - default=0.5, - help="individual test duration in seconds (default = 0.5)", + default=5.0, + help="maximum individual test duration in seconds (default = 5.0)", + ) + + p_run.add_argument( + "--num-msgs", + type=int, + default=1000, + help = "number of messages to send per-test (default = 1000)" ) p_run.add_argument( "--num-buffers", type=int, - default=32, - help="shared memory buffers (default = 32)", + default=1, + help="shared memory buffers (default = 1)", ) p_run.add_argument( "--iters", "-i", type = int, - default = 50, - help = "number of times to run each test (default = 50)" + default = 5, + help = "number of times to run each test (default = 5)" + ) + + p_run.add_argument( + "--repeats", "-r", + type = int, + default = 10, + help = "number of times to repeat the perf (default = 10)" ) p_run.add_argument( @@ -266,9 +298,11 @@ def setup_run_cmdline(subparsers: argparse._SubParsersAction) -> None: ) p_run.set_defaults(_handler=lambda ns: perf_run( - duration = ns.duration, + max_duration = ns.max_duration, + num_msgs = ns.num_msgs, num_buffers = ns.num_buffers, iters = ns.iters, + repeats = ns.repeats, msg_sizes = ns.msg_sizes, n_clients = ns.n_clients, comms = ns.comms, From c6b985f4a612963194478530344cd56d4242af71 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 10 Nov 2025 11:59:18 -0500 Subject: [PATCH 69/88] further simplifying tests --- src/ezmsg/util/perf/run.py | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/src/ezmsg/util/perf/run.py b/src/ezmsg/util/perf/run.py index 3c391bdc..ce6c2071 100644 --- a/src/ezmsg/util/perf/run.py +++ b/src/ezmsg/util/perf/run.py @@ -25,11 +25,11 @@ CONFIGS, ) -DEFAULT_MSG_SIZES = [2 ** exp for exp in range(4, 25, 8)] -DEFAULT_N_CLIENTS = [2 ** exp for exp in range(0, 6, 2)] +DEFAULT_MSG_SIZES = [2**4, 2**20] +DEFAULT_N_CLIENTS = [1, 16] DEFAULT_COMMS = [c for c in Communication] -# --- Output Suppression Context Manager (from the previous solution) --- +# --- Output Suppression Context Manager --- @contextmanager def suppress_output(verbose: bool = False): """Context manager to redirect stdout and stderr to os.devnull""" @@ -95,8 +95,8 @@ def perf_run( num_buffers: int, iters: int, repeats: int, - msg_sizes: typing.Iterable[int] | None, - n_clients: typing.Iterable[int] | None, + msg_sizes: typing.List[int] | None, + n_clients: typing.List[int] | None, comms: typing.Iterable[str] | None, configs: typing.Iterable[str] | None, grid: bool, @@ -117,8 +117,7 @@ def perf_run( if not grid and len(list(n_clients)) != len(list(msg_sizes)): ez.logger.warning( "Not performing a grid test of all combinations of n_clients and msg_sizes, but " + \ - "len(n_clients) != len(msg_sizes). " + \ - "If you want to perform all combinations of n_clients and msg_sizes, use --grid" + f"{len(n_clients)=} which is not equal to {len(msg_sizes)=}. " ) try: @@ -230,6 +229,19 @@ def setup_run_cmdline(subparsers: argparse._SubParsersAction) -> None: help = "number of messages to send per-test (default = 1000)" ) + # NOTE: We default num-buffers = 1 because this degenerate perf test scenario (blasting + # messages as fast as possible through the system) results in one of two scenerios: + # 1. A (few) messages is/are enqueued and dequeued before another message is posted + # 2. The buffer fills up before being FULLY emptied resulting in longer latency. + # (once a channel enters this condition, it tends to stay in this condition) + # + # This _indeterminate_ behavior results in bimodal distributions of runtimes that make + # A/B performance comparisons difficult. The perf test is not representative of the vast + # majority of production ezmsg systems where publishing is generally rate-limited. + # + # A flow-control algorithm could stabilize perf-test results with num_buffers > 1, but is + # generally implemented by enforcing delays on the publish side which simply degrades + # performance in the vast majority of ezmsg systems. - Griff p_run.add_argument( "--num-buffers", type=int, @@ -283,13 +295,6 @@ def setup_run_cmdline(subparsers: argparse._SubParsersAction) -> None: help = f"configurations to test (default = {[c for c in CONFIGS]})" ) - p_run.add_argument( - "--grid", - action = "store_true", - help = "perform all combinations of msg_sizes and n_clients " + \ - "(default: False; msg_sizes and n_clients must match in length)" - ) - p_run.add_argument( "--warmup", type = float, @@ -307,6 +312,6 @@ def setup_run_cmdline(subparsers: argparse._SubParsersAction) -> None: n_clients = ns.n_clients, comms = ns.comms, configs = ns.configs, - grid = ns.grid, + grid = True, warmup_dur = ns.warmup, )) \ No newline at end of file From b676e77f3ead6045b6812fbb8805fd7f5b2e1860 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Tue, 11 Nov 2025 12:50:59 -0500 Subject: [PATCH 70/88] bugfix: shm race --- src/ezmsg/core/graphserver.py | 2 +- src/ezmsg/core/messagechannel.py | 126 ++++++++++++++++++------------- src/ezmsg/core/server.py | 4 +- src/ezmsg/core/stream.py | 2 +- src/ezmsg/core/subclient.py | 4 +- 5 files changed, 80 insertions(+), 58 deletions(-) diff --git a/src/ezmsg/core/graphserver.py b/src/ezmsg/core/graphserver.py index 5bb5a519..1c4cb8be 100644 --- a/src/ezmsg/core/graphserver.py +++ b/src/ezmsg/core/graphserver.py @@ -45,7 +45,7 @@ class GraphServer(ThreadedAsyncServer): _command_lock: asyncio.Lock def __init__(self) -> None: - super().__init__() + super().__init__(name = "GraphServer") self.graph = DAG() self.clients = dict() self._client_tasks = dict() diff --git a/src/ezmsg/core/messagechannel.py b/src/ezmsg/core/messagechannel.py index 66b8c5a0..c08d3cbf 100644 --- a/src/ezmsg/core/messagechannel.py +++ b/src/ezmsg/core/messagechannel.py @@ -4,10 +4,11 @@ import logging from uuid import UUID -from contextlib import contextmanager, suppress +from copy import deepcopy +from contextlib import contextmanager, suppress, ExitStack from .shm import SHMContext -from .messagemarshal import MessageMarshal, NO_MESSAGE +from .messagemarshal import MessageMarshal from .backpressure import Backpressure from .graphserver import GraphService @@ -41,7 +42,7 @@ class Channel: topic: str num_buffers: int - tcp_cache: typing.List[memoryview] + contexts: typing.List[typing.ContextManager | None] cache: typing.List[typing.Any] cache_id: typing.List[int | None] shm: SHMContext | None @@ -67,7 +68,7 @@ def __init__( self.num_buffers = num_buffers self.shm = shm - self.tcp_cache = [memoryview(NO_MESSAGE)] * self.num_buffers + self.contexts = [None] * self.num_buffers self.cache_id = [None] * self.num_buffers self.cache = [None] * self.num_buffers self.backpressure = Backpressure(self.num_buffers) @@ -136,14 +137,10 @@ async def create( return chan def close(self) -> None: - if self.shm is not None: - self.shm.close() self._graph_task.cancel() self._pub_task.cancel() async def wait_closed(self) -> None: - if self.shm is not None: - await self.shm.wait_closed() with suppress(asyncio.CancelledError): await self._graph_task with suppress(asyncio.CancelledError): @@ -179,20 +176,21 @@ async def _publisher_connection(self, reader: asyncio.StreamReader) -> None: msg_id = await read_int(reader) buf_idx = msg_id % self.num_buffers - - # # Profiling indicates its faster to repeatedly - # # reconstruct <=512kB msgs from memory for up to - # # about 4 subs -- deepcopy is about 4x slower - # # which I suspect will be majority of cases - # # With more than 4 subs, one deepcopy may be faster - # # for <=512kB msgs, but becomes remarkably time - # # consuming for >= 512kB msgs. - # cache_msg = len(self.clients) > 4 + + ctx: typing.ContextManager | None = None + value: typing.Any = None if msg == Command.TX_SHM.value: shm_name = await read_str(reader) if self.shm is not None and self.shm.name != shm_name: + + shm_entries = [] + for i, c in enumerate(self.contexts): + if isinstance(c, ExitStack): + shm_entries.append(i) + c.close() + self.shm.close() await self.shm.wait_closed() try: @@ -202,27 +200,60 @@ async def _publisher_connection(self, reader: asyncio.StreamReader) -> None: "Invalid SHM received from publisher; may be dead" ) raise + + for i in shm_entries: + stack = ExitStack() + view = stack.enter_context(self.shm.buffer(i)) + assert MessageMarshal.msg_id(view) == self.cache_id[i] + self.cache[i] = stack.enter_context(MessageMarshal.obj_from_mem(view)) + self.contexts[i] = stack - # if cache_msg: - # ... + assert self.shm is not None + + ctx = ExitStack() + view = ctx.enter_context(self.shm.buffer(buf_idx)) + assert MessageMarshal.msg_id(view) == msg_id + value = ctx.enter_context(MessageMarshal.obj_from_mem(view)) elif msg == Command.TX_TCP.value: buf_size = await read_int(reader) obj_bytes = await reader.readexactly(buf_size) - self.tcp_cache[buf_idx] = memoryview(obj_bytes).toreadonly() + view = memoryview(obj_bytes).toreadonly() + ctx = MessageMarshal.obj_from_mem(view) + value = ctx.__enter__() - # if cache_msg: - # ... + else: + raise ValueError(f"unimplemented data telemetry: {msg}") - if not self._notify_clients(msg_id): + if self._notify_clients(msg_id): + self.cache_id[buf_idx] = msg_id + self.cache[buf_idx] = value + self.contexts[buf_idx] = ctx + else: # Nobody is listening; need to ack! self._acknowledge(msg_id) + ctx.__exit__(None, None, None) except (ConnectionResetError, BrokenPipeError, asyncio.IncompleteReadError): logger.debug(f"connection fail: channel:{self.id} - pub:{self.pub_id}") finally: await close_stream_writer(self._pub_writer) + + for i in range(self.num_buffers): + obj = self.cache[i] + self.cache[i] = None + del obj + + self.cache_id[i] = None + ctx = self.contexts[i] + if ctx is not None: + ctx.__exit__(None, None, None) + self.contexts[i] = None + + if self.shm is not None: + self.shm.close() + logger.debug(f"disconnected: channel:{self.id} -> pub:{self.pub_id}") def _notify_clients(self, msg_id: int) -> bool: @@ -252,38 +283,29 @@ def put_local(self, msg_id: int, msg: typing.Any) -> None: def get(self, msg_id: int, client_id: UUID) -> typing.Generator[typing.Any, None, None]: """get object from cache; if not in cache and shm provided -- get from shm""" - dispatched = False - buf_idx = msg_id % self.num_buffers - if self.cache_id[buf_idx] == msg_id: - dispatched = True + if self.cache_id[buf_idx] != msg_id: + raise CacheMiss + + try: yield self.cache[buf_idx] + finally: + self.backpressure.free(client_id, buf_idx) + if self.backpressure.buffers[buf_idx].is_empty: - elif self.shm is not None: - with self.shm.buffer(buf_idx, readonly=True) as mem: - if MessageMarshal.msg_id(mem) == msg_id: - with MessageMarshal.obj_from_mem(mem) as obj: - dispatched = True - yield obj - - if not dispatched: - tcp_mem = self.tcp_cache[buf_idx] - if MessageMarshal.msg_id(tcp_mem) == msg_id: - with MessageMarshal.obj_from_mem(tcp_mem) as obj: - dispatched = True - yield obj - else: - raise CacheMiss - - self.backpressure.free(client_id, buf_idx) - if self.backpressure.buffers[buf_idx].is_empty: - - # If pub is in same process as this channel, avoid TCP - if self._local_backpressure is not None: - self._local_backpressure.free(self.id, buf_idx) - return - - self._acknowledge(msg_id) + # If pub is in same process as this channel, avoid TCP + if self._local_backpressure is not None: + self._local_backpressure.free(self.id, buf_idx) + else: + self._acknowledge(msg_id) + + # Free cache and release contexts + self.cache[buf_idx] = None + self.cache_id[buf_idx] = None + ctx = self.contexts[buf_idx] + if ctx is not None: + ctx.__exit__(None, None, None) + self.contexts[buf_idx] = None def _acknowledge(self, msg_id: int) -> None: try: diff --git a/src/ezmsg/core/server.py b/src/ezmsg/core/server.py index 36fd03c6..5d5c8de6 100644 --- a/src/ezmsg/core/server.py +++ b/src/ezmsg/core/server.py @@ -29,8 +29,8 @@ class ThreadedAsyncServer(threading.Thread): _sock: socket.socket _loop: asyncio.AbstractEventLoop - def __init__(self) -> None: - super().__init__(daemon=True) + def __init__(self, **kwargs) -> None: + super().__init__(**{**kwargs, **dict(daemon=True)}) self._server_up = threading.Event() self._shutdown = threading.Event() diff --git a/src/ezmsg/core/stream.py b/src/ezmsg/core/stream.py index 52d2aae2..ad38ee4b 100644 --- a/src/ezmsg/core/stream.py +++ b/src/ezmsg/core/stream.py @@ -15,7 +15,7 @@ def __init__(self, msg_type: Any): def __repr__(self) -> str: _addr = self.address if self._location is not None else "unlocated" - return f"Stream:{_addr}[{self.msg_type}]" + return f"Stream:{_addr}[{self.msg_type.__name__}]" class InputStream(Stream): diff --git a/src/ezmsg/core/subclient.py b/src/ezmsg/core/subclient.py index d5a9f5b4..34be4fb6 100644 --- a/src/ezmsg/core/subclient.py +++ b/src/ezmsg/core/subclient.py @@ -144,11 +144,11 @@ async def _graph_connection( else: logger.warning( - f"Subscriber {self.id} rx unknown command from GraphServer: {cmd}" + f"Subscriber {self.topic}({self.id}) rx unknown command from GraphServer: {cmd}" ) except (ConnectionResetError, BrokenPipeError): - logger.debug(f"Subscriber {self.id} lost connection to graph server") + logger.debug(f"Subscriber {self.topic}({self.id}) lost connection to graph server") finally: for pub_id in self._channels: From 7b7cdceed64b4b26b8fb253a6b674b1d87965bcd Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Tue, 11 Nov 2025 12:51:48 -0500 Subject: [PATCH 71/88] tweak to remove nonstandard handling of state vars --- src/ezmsg/util/perf/impl.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/ezmsg/util/perf/impl.py b/src/ezmsg/util/perf/impl.py index a5eb18c2..51ae600b 100644 --- a/src/ezmsg/util/perf/impl.py +++ b/src/ezmsg/util/perf/impl.py @@ -67,15 +67,15 @@ class LoadTestSample: dynamic_data: np.ndarray key: str +class LoadTestSourceState(ez.State): + counter: int = 0 class LoadTestSource(ez.Unit): OUTPUT = ez.OutputStream(LoadTestSample) SETTINGS = LoadTestSettings + STATE = LoadTestSourceState async def initialize(self) -> None: - self.running = True - self.counter = 0 - self.OUTPUT.num_buffers = self.SETTINGS.buffers self.OUTPUT.force_tcp = self.SETTINGS.force_tcp @@ -84,8 +84,6 @@ async def publish(self) -> typing.AsyncGenerator: ez.logger.info(f"Load test publisher started. (PID: {os.getpid()})") start_time = time.time() for _ in range(self.SETTINGS.num_msgs): - if not self.running: - break current_time = time.time() if current_time - start_time >= self.SETTINGS.max_duration: @@ -95,21 +93,20 @@ async def publish(self) -> typing.AsyncGenerator: self.OUTPUT, LoadTestSample( _timestamp=time.time(), - counter=self.counter, + counter=self.STATE.counter, dynamic_data=np.zeros( int(self.SETTINGS.dynamic_size // 4), dtype=np.float32 ), key = self.name, ), ) - self.counter += 1 + self.STATE.counter += 1 ez.logger.info("Exiting publish") raise ez.Complete async def shutdown(self) -> None: - self.running = False - ez.logger.info(f"Samples sent: {self.counter}") + ez.logger.info(f"Samples sent: {self.STATE.counter}") From 746b49e8ba00645e80260b809a3b02fe04e0db1b Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Wed, 12 Nov 2025 12:07:03 -0500 Subject: [PATCH 72/88] implemented batch writes for stable perf testing with num_buffers > 1 --- src/ezmsg/core/backendprocess.py | 1 + src/ezmsg/core/pubclient.py | 45 ++++++++++++++++++++++---------- src/ezmsg/core/stream.py | 5 +++- 3 files changed, 36 insertions(+), 15 deletions(-) diff --git a/src/ezmsg/core/backendprocess.py b/src/ezmsg/core/backendprocess.py index 9c535a34..3bd4dde4 100644 --- a/src/ezmsg/core/backendprocess.py +++ b/src/ezmsg/core/backendprocess.py @@ -170,6 +170,7 @@ async def setup_state(): buf_size=stream.buf_size, start_paused=True, force_tcp=stream.force_tcp, + batch_write=stream.batch_write, ), loop=loop, ).result() diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index 78366547..12a5c119 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -5,7 +5,7 @@ from uuid import UUID from contextlib import suppress -from dataclasses import dataclass +from dataclasses import dataclass, field from .backpressure import Backpressure from .shm import SHMContext @@ -43,6 +43,7 @@ class PubChannelInfo(ChannelInfo): pid: int shm_ok: bool = False + batch: list[bytes] = field(default_factory = list) class Publisher: @@ -63,6 +64,7 @@ class Publisher: _msg_id: int _shm: SHMContext _force_tcp: bool + _batch_write: bool _last_backpressure_event: float _graph_address: AddressType | None @@ -82,6 +84,7 @@ async def create( num_buffers: int = 32, start_paused: bool = False, force_tcp: bool = False, + batch_write: bool = False, ) -> "Publisher": graph_service = GraphService(graph_address) reader, writer = await graph_service.open_connection() @@ -98,7 +101,8 @@ async def create( graph_address = graph_address, num_buffers = num_buffers, start_paused = start_paused, - force_tcp = force_tcp + force_tcp = force_tcp, + batch_write = batch_write ) start_port = int( @@ -153,6 +157,7 @@ def __init__( num_buffers: int = 32, start_paused: bool = False, force_tcp: bool = False, + batch_write: bool = False, ) -> None: """DO NOT USE this constructor to make a Publisher; use `create` instead""" self.id = id @@ -168,6 +173,7 @@ def __init__( self._num_buffers = num_buffers self._backpressure = Backpressure(num_buffers) self._force_tcp = force_tcp + self._batch_write = batch_write self._last_backpressure_event = -1 self._graph_address = graph_address @@ -321,33 +327,44 @@ async def broadcast(self, obj: Any) -> None: for channel in self._channels.values(): + msg: bytes = b'' + if self.pid == channel.pid and channel.shm_ok: continue # Local transmission handled by channel.put elif (not self._force_tcp) and self.pid != channel.pid and channel.shm_ok: - channel.writer.write( - Command.TX_SHM.value + \ - msg_id_bytes + \ + msg = ( + Command.TX_SHM.value + + msg_id_bytes + encode_str(self._shm.name) ) else: - channel.writer.write( - Command.TX_TCP.value + \ - msg_id_bytes + \ - total_size_bytes + \ - header + \ + msg = ( + Command.TX_TCP.value + + msg_id_bytes + + total_size_bytes + + header + b''.join([buffer for buffer in buffers]) ) - + try: - await channel.writer.drain() - self._backpressure.lease(channel.id, buf_idx) + if self._batch_write: + channel.batch.append(msg) + if len(channel.batch) == self._num_buffers: + channel.writer.write(b''.join(channel.batch)) + channel.batch.clear() + await channel.writer.drain() + for i in range(self._num_buffers): + self._backpressure.lease(channel.id, i) + else: + channel.writer.write(msg) + await channel.writer.drain() + self._backpressure.lease(channel.id, buf_idx) except (ConnectionResetError, BrokenPipeError): logger.debug( f"Publisher {self.id}: Channel {channel.id} connection fail" ) - continue self._msg_id += 1 diff --git a/src/ezmsg/core/stream.py b/src/ezmsg/core/stream.py index ad38ee4b..3b7e4702 100644 --- a/src/ezmsg/core/stream.py +++ b/src/ezmsg/core/stream.py @@ -37,6 +37,7 @@ class OutputStream(Stream): num_buffers: int buf_size: int force_tcp: bool + batch_write: bool def __init__( self, @@ -46,6 +47,7 @@ def __init__( num_buffers: int = 32, buf_size: int = DEFAULT_SHM_SIZE, force_tcp: bool = False, + batch_write: bool = False, ) -> None: super().__init__(msg_type) self.host = host @@ -53,7 +55,8 @@ def __init__( self.num_buffers = num_buffers self.buf_size = buf_size self.force_tcp = force_tcp + self.batch_write = batch_write def __repr__(self) -> str: preamble = f"Output{super().__repr__()}" - return f"{preamble}({self.num_buffers=}, {self.force_tcp=})" + return f"{preamble}({self.num_buffers=}, {self.force_tcp=}, {self.batch_write=})" From af8ff47687b0d56fb5e59eb0c385768584985fb0 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Fri, 14 Nov 2025 13:52:41 -0500 Subject: [PATCH 73/88] fix shm closure --- src/ezmsg/core/messagechannel.py | 235 +++++++++++++++++++++---------- src/ezmsg/core/messagemarshal.py | 32 ++--- src/ezmsg/core/pubclient.py | 11 +- src/ezmsg/core/shm.py | 21 ++- 4 files changed, 200 insertions(+), 99 deletions(-) diff --git a/src/ezmsg/core/messagechannel.py b/src/ezmsg/core/messagechannel.py index e18e0060..ebca0615 100644 --- a/src/ezmsg/core/messagechannel.py +++ b/src/ezmsg/core/messagechannel.py @@ -4,8 +4,7 @@ import logging from uuid import UUID -from copy import deepcopy -from contextlib import contextmanager, suppress, ExitStack +from contextlib import contextmanager, suppress from .shm import SHMContext from .messagemarshal import MessageMarshal @@ -33,6 +32,124 @@ class CacheMiss(Exception): ... +class CacheEntry(typing.NamedTuple): + object: typing.Any + msg_id: int + context: typing.ContextManager | None + memory: memoryview | None + + +class MessageCache: + + _cache: list[CacheEntry | None] + + def __init__(self, num_buffers: int) -> None: + self._cache = [None] * num_buffers + + def _buf_idx(self, msg_id: int) -> int: + return msg_id % len(self._cache) + + def __getitem__(self, msg_id: int) -> typing.Any: + """ + Get a cached object by msg_id + + :param msg_id: Message ID to retreive from cache + :type msg_id: int + :raises CacheMiss: If this msg_id does not exist in the cache. + """ + entry = self._cache[self._buf_idx(msg_id)] + if entry is None or entry.msg_id != msg_id: + raise CacheMiss + return entry.object + + def keys(self) -> list[int]: + """ + Get a list of current cached msg_ids + """ + return [entry.msg_id for entry in self._cache if entry is not None] + + def put_local(self, obj: typing.Any, msg_id: int) -> None: + """ + Put an object with associated msg_id directly into cache + + :param obj: Object to put in cache. + :type obj: typing.Any + :param msg_id: ID associated with this message/object. + :type msg_id: int + """ + self._put( + CacheEntry( + object = obj, + msg_id = msg_id, + context = None, + memory = None, + ) + ) + + def put_from_mem(self, mem: memoryview) -> None: + """ + Reconstitute a message in mem and keep it in cache, releasing and + overwriting the existing slot in cache. + This method passes the lifecycle of the memoryview to the MessageCache + and the memoryview will be properly released by the cache with `free` + + :param mem: Source memoryview containing serialized object. + :type from_mem: memoryview + :raises UninitializedMemory: If mem buffer is not properly initialized. + """ + ctx = MessageMarshal.obj_from_mem(mem) + self._put( + CacheEntry( + object = ctx.__enter__(), + msg_id = MessageMarshal.msg_id(mem), + context = ctx, + memory = mem, + ) + ) + + def _put(self, entry: CacheEntry) -> None: + buf_idx = self._buf_idx(entry.msg_id) + self._release(buf_idx) + self._cache[buf_idx] = entry + + def _release(self, buf_idx: int) -> None: + entry = self._cache[buf_idx] + if entry is not None: + mem = entry.memory + ctx = entry.context + if ctx is not None: + ctx.__exit__(None, None, None) + del entry + self._cache[buf_idx] = None + if mem is not None: + mem.release() + + def release(self, msg_id: int) -> None: + """ + Release memory for the entry associated with msg_id + + :param mem: Source memoryview containing serialized object. + :type from_mem: memoryview + :raises UninitializedMemory: If mem buffer is not properly initialized. + """ + buf_idx = self._buf_idx(msg_id) + entry = self._cache[buf_idx] + if entry is None or entry.msg_id != msg_id: + raise CacheMiss + self._release(buf_idx) + + def clear(self) -> None: + """ + Release all cached objects + + :param mem: Source memoryview containing serialized object. + :type from_mem: memoryview + :raises UninitializedMemory: If mem buffer is not properly initialized. + """ + for i in range(len(self._cache)): + self._release(i) + + class Channel: """cache-backed message channel for a particular publisher""" @@ -42,9 +159,7 @@ class Channel: topic: str num_buffers: int - contexts: list[typing.ContextManager | None] - cache: list[typing.Any] - cache_id: list[int | None] + cache: MessageCache shm: SHMContext | None clients: dict[UUID, NotificationQueue | None] backpressure: Backpressure @@ -68,9 +183,7 @@ def __init__( self.num_buffers = num_buffers self.shm = shm - self.contexts = [None] * self.num_buffers - self.cache_id = [None] * self.num_buffers - self.cache = [None] * self.num_buffers + self.cache = MessageCache(self.num_buffers) self.backpressure = Backpressure(self.num_buffers) self.clients = dict() self._graph_address = graph_address @@ -108,6 +221,7 @@ async def create( shm = await graph_service.attach_shm(shm_name) writer.write(Command.SHM_OK.value) except (ValueError, OSError): + shm = None writer.write(Command.SHM_ATTACH_FAILED.value) writer.write(uint64_to_bytes(os.getpid())) @@ -137,14 +251,16 @@ async def create( return chan def close(self) -> None: - self._graph_task.cancel() self._pub_task.cancel() + self._graph_task.cancel() async def wait_closed(self) -> None: - with suppress(asyncio.CancelledError): - await self._graph_task with suppress(asyncio.CancelledError): await self._pub_task + with suppress(asyncio.CancelledError): + await self._graph_task + if self.shm is not None: + await self.shm.wait_closed() async def _graph_connection( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter @@ -176,23 +292,17 @@ async def _publisher_connection(self, reader: asyncio.StreamReader) -> None: msg_id = await read_int(reader) buf_idx = msg_id % self.num_buffers - - ctx: typing.ContextManager | None = None - value: typing.Any = None if msg == Command.TX_SHM.value: shm_name = await read_str(reader) if self.shm is not None and self.shm.name != shm_name: - shm_entries = [] - for i, c in enumerate(self.contexts): - if isinstance(c, ExitStack): - shm_entries.append(i) - c.close() - + shm_entries = self.cache.keys() + self.cache.clear() self.shm.close() await self.shm.wait_closed() + try: self.shm = await GraphService(self._graph_address).attach_shm(shm_name) except ValueError: @@ -201,59 +311,43 @@ async def _publisher_connection(self, reader: asyncio.StreamReader) -> None: ) raise - for i in shm_entries: - stack = ExitStack() - view = stack.enter_context(self.shm.buffer(i)) - assert MessageMarshal.msg_id(view) == self.cache_id[i] - self.cache[i] = stack.enter_context(MessageMarshal.obj_from_mem(view)) - self.contexts[i] = stack + for id in shm_entries: + self.cache.put_from_mem(self.shm[id % self.num_buffers]) assert self.shm is not None - - ctx = ExitStack() - view = ctx.enter_context(self.shm.buffer(buf_idx)) - assert MessageMarshal.msg_id(view) == msg_id - value = ctx.enter_context(MessageMarshal.obj_from_mem(view)) + assert MessageMarshal.msg_id(self.shm[buf_idx]) == msg_id + self.cache.put_from_mem(self.shm[buf_idx]) elif msg == Command.TX_TCP.value: buf_size = await read_int(reader) obj_bytes = await reader.readexactly(buf_size) - view = memoryview(obj_bytes).toreadonly() - ctx = MessageMarshal.obj_from_mem(view) - value = ctx.__enter__() + assert MessageMarshal.msg_id(obj_bytes) == msg_id + self.cache.put_from_mem(memoryview(obj_bytes).toreadonly()) else: raise ValueError(f"unimplemented data telemetry: {msg}") - if self._notify_clients(msg_id): - self.cache_id[buf_idx] = msg_id - self.cache[buf_idx] = value - self.contexts[buf_idx] = ctx - else: + if not self._notify_clients(msg_id): # Nobody is listening; need to ack! + self.cache.release(msg_id) self._acknowledge(msg_id) - ctx.__exit__(None, None, None) except (ConnectionResetError, BrokenPipeError, asyncio.IncompleteReadError): logger.debug(f"connection fail: channel:{self.id} - pub:{self.pub_id}") finally: - await close_stream_writer(self._pub_writer) - - for i in range(self.num_buffers): - obj = self.cache[i] - self.cache[i] = None - del obj - - self.cache_id[i] = None - ctx = self.contexts[i] - if ctx is not None: - ctx.__exit__(None, None, None) - self.contexts[i] = None + # await close_stream_writer(self._pub_writer) + self._pub_writer.close() + + await self.backpressure.sync() + if self._local_backpressure is not None: + await self._local_backpressure.sync() + + self.cache.clear() if self.shm is not None: self.shm.close() - + logger.debug(f"disconnected: channel:{self.id} -> pub:{self.pub_id}") def _notify_clients(self, msg_id: int) -> bool: @@ -275,23 +369,28 @@ def put_local(self, msg_id: int, msg: typing.Any) -> None: buf_idx = msg_id % self.num_buffers if self._notify_clients(msg_id): - self.cache_id[buf_idx] = msg_id - self.cache[buf_idx] = msg + self.cache.put_local(msg, msg_id) self._local_backpressure.lease(self.id, buf_idx) @contextmanager def get(self, msg_id: int, client_id: UUID) -> typing.Generator[typing.Any, None, None]: - """get object from cache; if not in cache and shm provided -- get from shm""" - - buf_idx = msg_id % self.num_buffers - if self.cache_id[buf_idx] != msg_id: - raise CacheMiss + """ + Get a message + + :param msg_id: Message ID to retreive + :type msg_id: int + :param client_id: UUID of client retreiving this message for backpressure purposes + :type client_id: UUID + :raises CacheMiss: If this msg_id does not exist in the cache. + """ try: - yield self.cache[buf_idx] + yield self.cache[msg_id] finally: + buf_idx = msg_id % self.num_buffers self.backpressure.free(client_id, buf_idx) if self.backpressure.buffers[buf_idx].is_empty: + self.cache.release(msg_id) # If pub is in same process as this channel, avoid TCP if self._local_backpressure is not None: @@ -299,14 +398,6 @@ def get(self, msg_id: int, client_id: UUID) -> typing.Generator[typing.Any, None else: self._acknowledge(msg_id) - # Free cache and release contexts - self.cache[buf_idx] = None - self.cache_id[buf_idx] = None - ctx = self.contexts[buf_idx] - if ctx is not None: - ctx.__exit__(None, None, None) - self.contexts[buf_idx] = None - def _acknowledge(self, msg_id: int) -> None: try: ack = Command.RX_ACK.value + uint64_to_bytes(msg_id) @@ -342,10 +433,6 @@ def unregister_client(self, client_id: UUID) -> None: del self.clients[client_id] - def clear_cache(self): - self.cache_id = [None] * self.num_buffers - self.cache = [None] * self.num_buffers - def _ensure_address(address: AddressType | None) -> Address: if address is None: @@ -418,7 +505,7 @@ async def unregister( client_id: UUID, graph_address: AddressType | None = None ) -> None: - channel = await self.get(pub_id, graph_address) + channel = await self.get(pub_id, graph_address, create = False) channel.unregister_client(client_id) if len(channel.clients) == 0: diff --git a/src/ezmsg/core/messagemarshal.py b/src/ezmsg/core/messagemarshal.py index 5f7b119f..e9f571a0 100644 --- a/src/ezmsg/core/messagemarshal.py +++ b/src/ezmsg/core/messagemarshal.py @@ -84,25 +84,24 @@ def _write(cls, mem: memoryview, header: bytes, buffers: list[memoryview]): sidx += blen @classmethod - def _assert_initialized(cls, mem: memoryview) -> None: - if mem[:_PREAMBLE_LEN] != _PREAMBLE: + def _assert_initialized(cls, raw: memoryview | bytes) -> None: + if raw[:_PREAMBLE_LEN] != _PREAMBLE: raise UninitializedMemory @classmethod - def msg_id(cls, mem: memoryview) -> int | None: + def msg_id(cls, raw: memoryview | bytes) -> int: """ - Get msg_id currently written in mem; if uninitialized, return None. + Get msg_id from a buffer; if uninitialized, return None. - :param mem: Memory buffer to read from. - :type mem: memoryview - :return: Message ID if memory is initialized, None otherwise. - :rtype: int | None + :param mem: buffer to read from. + :type mem: memoryview | bytes + :return: Message ID of encoded message + :rtype: int + :raises UndersizedMemory: If buffer is not initialized. """ - try: - cls._assert_initialized(mem) - return bytes_to_uint(mem[_PREAMBLE_LEN : _PREAMBLE_LEN + UINT64_SIZE]) - except UninitializedMemory: - return None + cls._assert_initialized(raw) + return bytes_to_uint(raw[_PREAMBLE_LEN : _PREAMBLE_LEN + UINT64_SIZE]) + @classmethod @contextmanager @@ -214,11 +213,12 @@ def copy_obj(cls, from_mem: memoryview, to_mem: memoryview) -> None: :type from_mem: memoryview :param to_mem: Target memory buffer for copying. :type to_mem: memoryview + :raises UninitializedMemory: If from_mem buffer is not properly initialized. """ msg_id = cls.msg_id(from_mem) - if msg_id is not None: - with MessageMarshal.obj_from_mem(from_mem) as obj: - MessageMarshal.to_mem(msg_id, obj, to_mem) + with MessageMarshal.obj_from_mem(from_mem) as obj: + MessageMarshal.to_mem(msg_id, obj, to_mem) + # If some other byte-level representation is desired, you can just diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index 03d15aca..8feccec1 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -11,7 +11,7 @@ from .shm import SHMContext from .graphserver import GraphService from .messagechannel import CHANNELS, Channel -from .messagemarshal import MessageMarshal +from .messagemarshal import MessageMarshal, UninitializedMemory from .netprotocol import ( Address, @@ -440,9 +440,12 @@ async def broadcast(self, obj: Any) -> None: new_shm = await GraphService(self._graph_address).create_shm(self._num_buffers, total_size * 2) for i in range(self._num_buffers): - with self._shm.buffer(i, readonly=True) as from_buf: - with new_shm.buffer(i) as to_buf: - MessageMarshal.copy_obj(from_buf, to_buf) + try: + with self._shm.buffer(i, readonly=True) as from_buf: + with new_shm.buffer(i) as to_buf: + MessageMarshal.copy_obj(from_buf, to_buf) + except UninitializedMemory: + pass self._shm.close() self._shm = new_shm diff --git a/src/ezmsg/core/shm.py b/src/ezmsg/core/shm.py index 5d961b32..2ca55bca 100644 --- a/src/ezmsg/core/shm.py +++ b/src/ezmsg/core/shm.py @@ -125,8 +125,22 @@ async def _graph_connection( except (ConnectionResetError, BrokenPipeError) as e: logger.debug(f"SHMContext {self.name} GraphServer {type(e)}") finally: - await close_stream_writer(writer) self._shm.close() + await close_stream_writer(writer) + + def __getitem__(self, idx: int) -> memoryview: + """ + Get a memory view of a specific buffer in the shared memory segment. + + :param idx: Index of the buffer to access. + :type idx: int + :return: A memoryview of the buffer. + :rtype: memoryview + :raises BufferError: If shared memory is no longer accessible. + """ + if self._shm.buf is None: + raise BufferError(f"cannot access {self.name}: server disconnected") + return self._shm.buf[self._data_block_segs[idx]] @contextmanager def buffer( @@ -143,10 +157,7 @@ def buffer( :rtype: collections.abc.Generator[memoryview, None, None] :raises BufferError: If shared memory is no longer accessible. """ - if self._shm.buf is None: - raise BufferError(f"cannot access {self.name}: server disconnected") - - with self._shm.buf[self._data_block_segs[idx]] as mem: + with self[idx] as mem: if readonly: ro_mem = mem.toreadonly() yield ro_mem From 39d9cc570c862815ac4c91ceca5fd5f5b2500e8a Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Fri, 14 Nov 2025 13:54:45 -0500 Subject: [PATCH 74/88] revert batch write; eventually should implement proper flow control --- src/ezmsg/core/backendprocess.py | 1 - src/ezmsg/core/pubclient.py | 21 +++------------------ src/ezmsg/core/stream.py | 5 +---- src/ezmsg/util/perf/impl.py | 1 - 4 files changed, 4 insertions(+), 24 deletions(-) diff --git a/src/ezmsg/core/backendprocess.py b/src/ezmsg/core/backendprocess.py index 307538d0..1eb67b16 100644 --- a/src/ezmsg/core/backendprocess.py +++ b/src/ezmsg/core/backendprocess.py @@ -219,7 +219,6 @@ async def setup_state(): buf_size=stream.buf_size, start_paused=True, force_tcp=stream.force_tcp, - batch_write=stream.batch_write, ), loop=loop, ).result() diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index 8feccec1..c1b8288f 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -43,7 +43,6 @@ class PubChannelInfo(ChannelInfo): pid: int shm_ok: bool = False - batch: list[bytes] = field(default_factory = list) class Publisher: @@ -72,7 +71,6 @@ class Publisher: _msg_id: int _shm: SHMContext _force_tcp: bool - _batch_write: bool _last_backpressure_event: float _graph_address: AddressType | None @@ -99,7 +97,6 @@ async def create( num_buffers: int = 32, start_paused: bool = False, force_tcp: bool = False, - batch_write: bool = False, ) -> "Publisher": """ Create a new Publisher instance and register it with the graph server. @@ -136,7 +133,6 @@ async def create( num_buffers = num_buffers, start_paused = start_paused, force_tcp = force_tcp, - batch_write = batch_write ) start_port = int( @@ -191,7 +187,6 @@ def __init__( num_buffers: int = 32, start_paused: bool = False, force_tcp: bool = False, - batch_write: bool = False, ) -> None: """ Initialize a Publisher instance. @@ -223,7 +218,6 @@ def __init__( self._num_buffers = num_buffers self._backpressure = Backpressure(num_buffers) self._force_tcp = force_tcp - self._batch_write = batch_write self._last_backpressure_event = -1 self._graph_address = graph_address @@ -477,18 +471,9 @@ async def broadcast(self, obj: Any) -> None: ) try: - if self._batch_write: - channel.batch.append(msg) - if len(channel.batch) == self._num_buffers: - channel.writer.write(b''.join(channel.batch)) - channel.batch.clear() - await channel.writer.drain() - for i in range(self._num_buffers): - self._backpressure.lease(channel.id, i) - else: - channel.writer.write(msg) - await channel.writer.drain() - self._backpressure.lease(channel.id, buf_idx) + channel.writer.write(msg) + await channel.writer.drain() + self._backpressure.lease(channel.id, buf_idx) except (ConnectionResetError, BrokenPipeError): logger.debug( diff --git a/src/ezmsg/core/stream.py b/src/ezmsg/core/stream.py index 34dcedeb..178e746c 100644 --- a/src/ezmsg/core/stream.py +++ b/src/ezmsg/core/stream.py @@ -67,7 +67,6 @@ class OutputStream(Stream): num_buffers: int buf_size: int force_tcp: bool - batch_write: bool def __init__( self, @@ -77,7 +76,6 @@ def __init__( num_buffers: int = 32, buf_size: int = DEFAULT_SHM_SIZE, force_tcp: bool = False, - batch_write: bool = False, ) -> None: super().__init__(msg_type) self.host = host @@ -85,8 +83,7 @@ def __init__( self.num_buffers = num_buffers self.buf_size = buf_size self.force_tcp = force_tcp - self.batch_write = batch_write def __repr__(self) -> str: preamble = f"Output{super().__repr__()}" - return f"{preamble}({self.num_buffers=}, {self.force_tcp=}, {self.batch_write=})" + return f"{preamble}({self.num_buffers=}, {self.force_tcp=})" diff --git a/src/ezmsg/util/perf/impl.py b/src/ezmsg/util/perf/impl.py index 1a01dc70..fc46b81e 100644 --- a/src/ezmsg/util/perf/impl.py +++ b/src/ezmsg/util/perf/impl.py @@ -78,7 +78,6 @@ class LoadTestSource(ez.Unit): async def initialize(self) -> None: self.OUTPUT.num_buffers = self.SETTINGS.buffers self.OUTPUT.force_tcp = self.SETTINGS.force_tcp - self.OUTPUT.batch_write = True @ez.publisher(OUTPUT) async def publish(self) -> typing.AsyncGenerator: From 603f8a44aef19dc29617084cbe304761b95dec60 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 17 Nov 2025 09:30:35 -0500 Subject: [PATCH 75/88] addressed deadlocks on shutdown and proper shm deallocation --- src/ezmsg/core/graphcontext.py | 2 +- src/ezmsg/core/messagechannel.py | 49 ++++++++++++-------------------- src/ezmsg/core/messagemarshal.py | 7 +++++ src/ezmsg/core/pubclient.py | 1 + src/ezmsg/core/shm.py | 19 ++++++------- 5 files changed, 36 insertions(+), 42 deletions(-) diff --git a/src/ezmsg/core/graphcontext.py b/src/ezmsg/core/graphcontext.py index 9eefbe16..03ee5763 100644 --- a/src/ezmsg/core/graphcontext.py +++ b/src/ezmsg/core/graphcontext.py @@ -160,7 +160,7 @@ async def revert(self) -> None: """ Revert all changes made by this context. - This method closes all clients (publishers and subscribers) created by this + This method closes all publishers and subscribers created by this context and removes all edges that were added to the graph. It is automatically called when exiting the context manager. """ diff --git a/src/ezmsg/core/messagechannel.py b/src/ezmsg/core/messagechannel.py index ebca0615..35c664c5 100644 --- a/src/ezmsg/core/messagechannel.py +++ b/src/ezmsg/core/messagechannel.py @@ -336,18 +336,12 @@ async def _publisher_connection(self, reader: asyncio.StreamReader) -> None: logger.debug(f"connection fail: channel:{self.id} - pub:{self.pub_id}") finally: - # await close_stream_writer(self._pub_writer) - self._pub_writer.close() - - await self.backpressure.sync() - if self._local_backpressure is not None: - await self._local_backpressure.sync() - self.cache.clear() - if self.shm is not None: self.shm.close() + await close_stream_writer(self._pub_writer) + logger.debug(f"disconnected: channel:{self.id} -> pub:{self.pub_id}") def _notify_clients(self, msg_id: int) -> bool: @@ -451,25 +445,7 @@ class ChannelManager: def __init__(self): default_address = Address.from_string(GRAPHSERVER_ADDR) self._registry = {default_address: dict()} - self._lock = asyncio.Lock() - async def get( - self, - pub_id: UUID, - graph_address: AddressType | None = None, - create: bool = False - ) -> Channel: - graph_address = _ensure_address(graph_address) - channel = self._registry.get(graph_address, dict()).get(pub_id, None) - if create and channel is None: - channel = await Channel.create(pub_id, graph_address) - channels = self._registry.get(graph_address, dict()) - channels[pub_id] = channel - self._registry[graph_address] = channels - if channel is None: - raise KeyError(f"channel {pub_id=} {graph_address=} does not exist") - return channel - async def register( self, pub_id: UUID, @@ -495,7 +471,14 @@ async def _register( graph_address: AddressType | None = None, local_backpressure: Backpressure | None = None ) -> Channel: - channel = await self.get(pub_id, graph_address, create = True) + graph_address = _ensure_address(graph_address) + try: + channel = self._registry.get(graph_address, dict())[pub_id] + except KeyError: + channel = await Channel.create(pub_id, graph_address) + channels = self._registry.get(graph_address, dict()) + channels[pub_id] = channel + self._registry[graph_address] = channels channel.register_client(client_id, queue, local_backpressure) return channel @@ -505,15 +488,19 @@ async def unregister( client_id: UUID, graph_address: AddressType | None = None ) -> None: - channel = await self.get(pub_id, graph_address, create = False) + graph_address = _ensure_address(graph_address) + channel = self._registry.get(graph_address, dict())[pub_id] channel.unregister_client(client_id) + logger.debug(f'unregistered {client_id} from {pub_id}; {len(channel.clients)} left') + if len(channel.clients) == 0: - channel.close() - await channel.wait_closed() - graph_address = _ensure_address(graph_address) registry = self._registry[graph_address] del registry[pub_id] + + channel.close() + await channel.wait_closed() + logger.debug(f'closed channel {pub_id}: no clients') diff --git a/src/ezmsg/core/messagemarshal.py b/src/ezmsg/core/messagemarshal.py index e9f571a0..31917117 100644 --- a/src/ezmsg/core/messagemarshal.py +++ b/src/ezmsg/core/messagemarshal.py @@ -140,6 +140,13 @@ def obj_from_mem(cls, mem: memoryview) -> Generator[Any, None, None]: try: yield obj + + except GeneratorExit: + # This happens when we need to close shm/clear cache + # on shutdown and user code may have its fingers in + # memory that we need to free + pass + finally: del obj for buf in buffers: diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index c1b8288f..4f57a3a5 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -442,6 +442,7 @@ async def broadcast(self, obj: Any) -> None: pass self._shm.close() + await self._shm.wait_closed() self._shm = new_shm with self._shm.buffer(buf_idx) as mem: diff --git a/src/ezmsg/core/shm.py b/src/ezmsg/core/shm.py index 2ca55bca..03597600 100644 --- a/src/ezmsg/core/shm.py +++ b/src/ezmsg/core/shm.py @@ -157,13 +157,11 @@ def buffer( :rtype: collections.abc.Generator[memoryview, None, None] :raises BufferError: If shared memory is no longer accessible. """ - with self[idx] as mem: - if readonly: - ro_mem = mem.toreadonly() - yield ro_mem - ro_mem.release() - else: - yield mem + mem = self[idx].toreadonly() if readonly else self[idx] + try: + yield mem + finally: + mem.release() def close(self) -> None: """ @@ -247,6 +245,8 @@ def lease( async def _wait_for_eof() -> None: try: await reader.read() + except (ConnectionResetError, BrokenPipeError): + pass finally: await close_stream_writer(writer) @@ -255,10 +255,9 @@ async def _wait_for_eof() -> None: self.leases.add(lease) return lease - def _release(self, task: "asyncio.Task[None]"): - self.leases.discard(task) - logger.debug(f"discarded lease from {self.shm.name}") + self.leases.remove(task) + logger.debug(f"removed lease from {self.shm.name}; {len(self.leases)} left") if len(self.leases) == 0: logger.debug(f"unlinking {self.shm.name}") self.shm.close() From c097b77315dc80c764545225c2af5f70855ff21c Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 17 Nov 2025 11:46:14 -0500 Subject: [PATCH 76/88] revert except clause --- src/ezmsg/core/messagemarshal.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/ezmsg/core/messagemarshal.py b/src/ezmsg/core/messagemarshal.py index 31917117..e9f571a0 100644 --- a/src/ezmsg/core/messagemarshal.py +++ b/src/ezmsg/core/messagemarshal.py @@ -140,13 +140,6 @@ def obj_from_mem(cls, mem: memoryview) -> Generator[Any, None, None]: try: yield obj - - except GeneratorExit: - # This happens when we need to close shm/clear cache - # on shutdown and user code may have its fingers in - # memory that we need to free - pass - finally: del obj for buf in buffers: From 256231cfa6b4a2df22e5cf42266c9506bc220e87 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Wed, 19 Nov 2025 09:47:38 -0500 Subject: [PATCH 77/88] fixes #204 --- src/ezmsg/core/graphserver.py | 154 +++++++++++++++++--- src/ezmsg/core/server.py | 256 ---------------------------------- src/ezmsg/core/shm.py | 4 +- 3 files changed, 134 insertions(+), 280 deletions(-) delete mode 100644 src/ezmsg/core/server.py diff --git a/src/ezmsg/core/graphserver.py b/src/ezmsg/core/graphserver.py index 55036082..c6d9226f 100644 --- a/src/ezmsg/core/graphserver.py +++ b/src/ezmsg/core/graphserver.py @@ -1,6 +1,9 @@ import asyncio import logging import pickle +import os +import socket +import threading from contextlib import suppress from uuid import UUID, uuid1 @@ -24,15 +27,17 @@ GRAPHSERVER_ADDR_ENV, GRAPHSERVER_PORT_DEFAULT, DEFAULT_SHM_SIZE, + create_socket, + SERVER_PORT_START_ENV, + SERVER_PORT_START_DEFAULT, + DEFAULT_HOST, ) -from .server import ServiceManager, ThreadedAsyncServer from .shm import SHMContext, SHMInfo - logger = logging.getLogger("ezmsg") -class GraphServer(ThreadedAsyncServer): +class GraphServer(threading.Thread): """ Pub-sub directed acyclic graph (DAG) server. @@ -46,6 +51,11 @@ class GraphServer(ThreadedAsyncServer): The GraphServer is typically managed automatically by the ezmsg runtime and doesn't need to be instantiated directly by user code. """ + _server_up: threading.Event + _shutdown: threading.Event + + _sock: socket.socket + _loop: asyncio.AbstractEventLoop graph: DAG clients: dict[UUID, ClientInfo] @@ -54,23 +64,58 @@ class GraphServer(ThreadedAsyncServer): _client_tasks: dict[UUID, "asyncio.Task[None]"] _command_lock: asyncio.Lock - def __init__(self) -> None: - super().__init__(name = "GraphServer") + def __init__(self, **kwargs) -> None: + super().__init__(**{**kwargs, **dict(daemon=True, name=kwargs.get("name", "GraphServer"))}) + # threading events for lifecycle + self._server_up = threading.Event() + self._shutdown = threading.Event() + + # graph/server data self.graph = DAG() - self.clients = dict() - self._client_tasks = dict() - self.shms = dict() + self.clients = {} + self._client_tasks = {} + self.shms = {} + + @property + def address(self) -> Address: + return Address(*self._sock.getsockname()) + + def start(self, address: AddressType | None = None) -> None: # type: ignore[override] + if address is not None: + self._sock = create_socket(*address) + else: + start_port = int(os.environ.get(SERVER_PORT_START_ENV, SERVER_PORT_START_DEFAULT)) + self._sock = create_socket(start_port=start_port) + + self._loop = asyncio.new_event_loop() + super().start() + self._server_up.wait() + + def stop(self) -> None: + self._shutdown.set() + self.join() + + def run(self) -> None: + try: + asyncio.set_event_loop(self._loop) + with suppress(asyncio.CancelledError): + self._loop.run_until_complete(self._amain()) + finally: + self._loop.stop() + self._loop.close() - async def setup(self) -> None: + async def _setup(self) -> None: self._command_lock = asyncio.Lock() - async def shutdown(self) -> None: - for task in self._client_tasks.values(): + async def _shutdown_async(self) -> None: + # Cancel client handler tasks and wait for them to end. + for task in list(self._client_tasks.values()): task.cancel() - with suppress(asyncio.CancelledError): - await task + with suppress(asyncio.CancelledError): + await asyncio.gather(*self._client_tasks.values(), return_exceptions=True) self._client_tasks.clear() + # Cancel SHM leases for info in self.shms.values(): for lease_task in list(info.leases): lease_task.cancel() @@ -78,6 +123,30 @@ async def shutdown(self) -> None: await lease_task info.leases.clear() + async def _amain(self) -> None: + """ + Start the asyncio server and serve forever until shutdown is requested. + """ + await self._setup() + + server = await asyncio.start_server(self.api, sock=self._sock) + + async def monitor_shutdown() -> None: + # Thread event -> wake in event loop + await self._loop.run_in_executor(None, self._shutdown.wait) + server.close() + await self._shutdown_async() + await server.wait_closed() + + monitor_task = self._loop.create_task(monitor_shutdown()) + self._server_up.set() + + try: + await server.serve_forever() + finally: + self._shutdown.set() + await monitor_task + async def api( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter ) -> None: @@ -107,7 +176,6 @@ async def api( shm_info: SHMInfo | None = None if req == Command.SHM_CREATE.value: - # TODO: UUID num_buffers = await read_int(reader) buf_size = await read_int(reader) @@ -307,8 +375,10 @@ async def _handle_client( logger.debug(f"Client {client_id} disconnected from GraphServer: {e}") finally: + # Ensure any waiter on this client unblocks + # with suppress(Exception): self.clients[client_id].set_sync() - del self.clients[client_id] + self.clients.pop(client_id, None) await close_stream_writer(writer) async def _notify_subscriber(self, sub: SubscriberInfo) -> None: @@ -351,17 +421,56 @@ def _downstream_subs(self, topic: str) -> list[SubscriberInfo]: return [sub for sub in self._subscribers() if sub.topic in downstream_topics] -class GraphService(ServiceManager[GraphServer]): +class GraphService: ADDR_ENV = GRAPHSERVER_ADDR_ENV PORT_DEFAULT = GRAPHSERVER_PORT_DEFAULT + _address: Address | None + def __init__(self, address: AddressType | None = None) -> None: - super().__init__(GraphServer, address) + self._address = Address(*address) if address is not None else None + + @classmethod + def default_address(cls) -> Address: + address_str = os.environ.get(cls.ADDR_ENV, f"{DEFAULT_HOST}:{cls.PORT_DEFAULT}") + return Address.from_string(address_str) + + @property + def address(self) -> Address: + return self._address if self._address is not None else self.default_address() + + def create_server(self) -> GraphServer: + server = GraphServer(name="GraphServer") + server.start(self._address) + self._address = server.address + return server + + async def ensure(self) -> GraphServer | None: + """ + Try connecting to an existing server. If none is listening and no explicit + address/environment is set, start one and return it. If an existing one is + found, return None. + """ + server = None + ensure_server = False + if self._address is None: + # Only auto-start if env var not forcing a location + ensure_server = self.ADDR_ENV not in os.environ + + try: + reader, writer = await self.open_connection() + await close_stream_writer(writer) + except OSError as ref_e: + if not ensure_server: + raise ref_e + server = self.create_server() + + return server async def open_connection( self, ) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]: - reader, writer = await super().open_connection() + reader, writer = await asyncio.open_connection(*(self.address)) server_version = await read_str(reader) if server_version != __version__: logger.warning( @@ -377,6 +486,7 @@ async def connect(self, from_topic: str, to_topic: str) -> None: await writer.drain() response = await reader.read(1) if response == Command.CYCLIC.value: + await close_stream_writer(writer) raise CyclicException await close_stream_writer(writer) @@ -464,19 +574,18 @@ async def get_formatted_graph( async def create_shm( self, - # TODO: add UUID parameter - num_buffers: int, - buf_size: int = DEFAULT_SHM_SIZE + num_buffers: int, + buf_size: int = DEFAULT_SHM_SIZE, ) -> SHMContext: reader, writer = await self.open_connection() writer.write(Command.SHM_CREATE.value) - # TODO: serialize UUID writer.write(uint64_to_bytes(num_buffers)) writer.write(uint64_to_bytes(buf_size)) await writer.drain() response = await reader.read(1) if response != Command.COMPLETE.value: + await close_stream_writer(writer) raise ValueError("Error creating SHM segment") shm_name = await read_str(reader) @@ -490,6 +599,7 @@ async def attach_shm(self, name: str) -> SHMContext: response = await reader.read(1) if response != Command.COMPLETE.value: + await close_stream_writer(writer) raise ValueError("Invalid SHM Name") shm_name = await read_str(reader) diff --git a/src/ezmsg/core/server.py b/src/ezmsg/core/server.py deleted file mode 100644 index 89b2d584..00000000 --- a/src/ezmsg/core/server.py +++ /dev/null @@ -1,256 +0,0 @@ -import os -from collections.abc import Callable -import asyncio -import logging -import socket -import typing -import threading - -from contextlib import suppress - -from .netprotocol import ( - Address, - AddressType, - DEFAULT_HOST, - close_stream_writer, - create_socket, - SERVER_PORT_START_ENV, - SERVER_PORT_START_DEFAULT, -) - -logger = logging.getLogger("ezmsg") - -class ThreadedAsyncServer(threading.Thread): - """ - An asyncio server that runs in a dedicated loop in a separate thread. - - This class provides a foundation for running asyncio-based servers in their own - threads, allowing for concurrent operation with other parts of the application. - - :obj:`GraphServer` inherits from this class to implement specific server functionality. - """ - - _server_up: threading.Event - _shutdown: threading.Event - - _sock: socket.socket - _loop: asyncio.AbstractEventLoop - - def __init__(self, **kwargs) -> None: - """ - Initialize the threaded async server. - """ - super().__init__(**{**kwargs, **dict(daemon=True)}) - self._server_up = threading.Event() - self._shutdown = threading.Event() - - @property - def address(self) -> Address: - return Address(*self._sock.getsockname()) - - def start(self, address: AddressType | None = None) -> None: - if address is not None: - self._sock = create_socket(*address) - else: - start_port = int( - os.environ.get(SERVER_PORT_START_ENV, SERVER_PORT_START_DEFAULT) - ) - self._sock = create_socket(start_port=start_port) - - self._loop = asyncio.new_event_loop() - super().start() - self._server_up.wait() - - def stop(self) -> None: - self._shutdown.set() - self.join() - - def run(self) -> None: - """ - Run the async server in its own thread. - - This method starts a new asyncio event loop and runs the server's main - execution logic within that loop. - """ - try: - asyncio.set_event_loop(self._loop) - with suppress(asyncio.CancelledError): - self._loop.run_until_complete(self.amain()) - finally: - self._loop.stop() - self._loop.close() - - async def amain(self) -> None: - """ - Main asynchronous execution method for the server. - - This abstract method should be implemented by subclasses to define - the specific server behavior and async operations. - """ - await self.setup() - - server = await asyncio.start_server(self.api, sock=self._sock) - - async def monitor_shutdown() -> None: - await self._loop.run_in_executor(None, self._shutdown.wait) - server.close() - await server.wait_closed() - - monitor_task = self._loop.create_task(monitor_shutdown()) - - self._server_up.set() - - try: - await server.serve_forever() - - finally: - await self.shutdown() - self._shutdown.set() - await monitor_task - - async def setup(self) -> None: - """ - Setup method called before server starts serving. - - This method can be overridden by subclasses to perform any - initialization needed before the server begins accepting connections. - """ - ... - - async def shutdown(self) -> None: - """ - Shutdown method called when server is stopping. - - This method can be overridden by subclasses to perform any - cleanup needed when the server is shutting down. - """ - ... - - async def api( - self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter - ) -> None: - """ - API handler for client connections. - - This abstract method must be implemented by subclasses to handle - the server's communication protocol with connected clients. - - :param reader: Stream reader for receiving data from clients. - :type reader: asyncio.StreamReader - :param writer: Stream writer for sending data to clients. - :type writer: asyncio.StreamWriter - :raises NotImplementedError: Must be implemented by subclasses. - """ - raise NotImplementedError - - -T = typing.TypeVar("T", bound=ThreadedAsyncServer) - -# TODO: Get rid of Service Manager and drop this in its own file - -class ServiceManager(typing.Generic[T]): - """ - Manages the lifecycle and connection of threaded async servers. - - This generic class provides utilities for ensuring servers are running, - managing connections, and handling server creation with proper addressing. - - :type T: A ThreadedAsyncServer subclass type. - """ - _address: Address | None = None - _factory: Callable[[], T] - - ADDR_ENV: str - PORT_DEFAULT: int - - def __init__( - self, - factory: Callable[[], T], - address: AddressType | None = None, - ) -> None: - """ - Initialize the service manager. - - :param factory: Factory function that creates server instances. - :type factory: collections.abc.Callable[[], T] - :param address: Optional address tuple (host, port) for the server. - :type address: AddressType | None - """ - self._factory = factory - if address is not None: - self._address = Address(*address) - - async def ensure(self) -> T | None: - """ - Ensure a server is running and accessible. - - Attempts to connect to an existing server. If connection fails and - we should ensure a server exists, creates a new server instance. - - :return: The server instance if one was created, None if using existing. - :rtype: T | None - :raises OSError: If connection fails and no server should be created. - """ - server = None - ensure_server = False - if self._address is None: - ensure_server = self.ADDR_ENV not in os.environ - - try: - reader, writer = await self.open_connection() - await close_stream_writer(writer) - - except OSError as ref_e: - if not ensure_server: - raise ref_e - - server = self.create_server() - - return server - - @property - def address(self) -> Address: - """ - Get the server address. - - :return: The configured address or default address if none set. - :rtype: Address - """ - return self._address if self._address is not None else self.default_address() - - @classmethod - def default_address(cls) -> Address: - """ - Get the default server address from environment or class constants. - - :return: Address parsed from environment variable or default host:port. - :rtype: Address - """ - address_str = os.environ.get(cls.ADDR_ENV, f"{DEFAULT_HOST}:{cls.PORT_DEFAULT}") - return Address.from_string(address_str) - - def create_server(self) -> T: - """ - Create and start a new server instance. - - :return: The newly created and started server instance. - :rtype: T - """ - server = self._factory() - server.start(self._address) - self._address = server.address - return server - - async def open_connection( - self, - ) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]: - """ - Open a connection to the managed server. - - :return: Tuple of (reader, writer) for communicating with the server. - :rtype: tuple[asyncio.StreamReader, asyncio.StreamWriter] - :raises OSError: If connection cannot be established. - """ - return await asyncio.open_connection( - *(self.default_address() if self._address is None else self._address) - ) diff --git a/src/ezmsg/core/shm.py b/src/ezmsg/core/shm.py index 03597600..4c5ad140 100644 --- a/src/ezmsg/core/shm.py +++ b/src/ezmsg/core/shm.py @@ -256,8 +256,8 @@ async def _wait_for_eof() -> None: return lease def _release(self, task: "asyncio.Task[None]"): - self.leases.remove(task) - logger.debug(f"removed lease from {self.shm.name}; {len(self.leases)} left") + self.leases.discard(task) + logger.debug(f"discarded lease from {self.shm.name}; {len(self.leases)} left") if len(self.leases) == 0: logger.debug(f"unlinking {self.shm.name}") self.shm.close() From 1e4581111d011688a2d3283575a17ca103540921 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Wed, 19 Nov 2025 10:08:11 -0500 Subject: [PATCH 78/88] refactor classes out of messagechannel --- src/ezmsg/core/channelmanager.py | 91 ++++++++++++++ src/ezmsg/core/messagecache.py | 124 +++++++++++++++++++ src/ezmsg/core/messagechannel.py | 203 +------------------------------ src/ezmsg/core/pubclient.py | 5 +- src/ezmsg/core/subclient.py | 4 +- 5 files changed, 221 insertions(+), 206 deletions(-) create mode 100644 src/ezmsg/core/channelmanager.py create mode 100644 src/ezmsg/core/messagecache.py diff --git a/src/ezmsg/core/channelmanager.py b/src/ezmsg/core/channelmanager.py new file mode 100644 index 00000000..03815f69 --- /dev/null +++ b/src/ezmsg/core/channelmanager.py @@ -0,0 +1,91 @@ +import logging + +from uuid import UUID + +from .messagechannel import Channel, NotificationQueue +from .backpressure import Backpressure +from .netprotocol import ( + Address, + AddressType, + GRAPHSERVER_ADDR +) + +logger = logging.getLogger("ezmsg") + +def _ensure_address(address: AddressType | None) -> Address: + if address is None: + return Address.from_string(GRAPHSERVER_ADDR) + + elif not isinstance(address, Address): + return Address(*address) + + return address + + +class ChannelManager: + + _registry: dict[Address, dict[UUID, Channel]] + + def __init__(self): + default_address = Address.from_string(GRAPHSERVER_ADDR) + self._registry = {default_address: dict()} + + async def register( + self, + pub_id: UUID, + client_id: UUID, + queue: NotificationQueue, + graph_address: AddressType | None = None, + ) -> Channel: + return await self._register(pub_id, client_id, queue, graph_address, None) + + async def register_local_pub( + self, + pub_id: UUID, + local_backpressure: Backpressure | None = None, + graph_address: AddressType | None = None, + ) -> Channel: + return await self._register(pub_id, pub_id, None, graph_address, local_backpressure) + + async def _register( + self, + pub_id: UUID, + client_id: UUID, + queue: NotificationQueue | None = None, + graph_address: AddressType | None = None, + local_backpressure: Backpressure | None = None + ) -> Channel: + graph_address = _ensure_address(graph_address) + try: + channel = self._registry.get(graph_address, dict())[pub_id] + except KeyError: + channel = await Channel.create(pub_id, graph_address) + channels = self._registry.get(graph_address, dict()) + channels[pub_id] = channel + self._registry[graph_address] = channels + channel.register_client(client_id, queue, local_backpressure) + return channel + + async def unregister( + self, + pub_id: UUID, + client_id: UUID, + graph_address: AddressType | None = None + ) -> None: + graph_address = _ensure_address(graph_address) + channel = self._registry.get(graph_address, dict())[pub_id] + channel.unregister_client(client_id) + + logger.debug(f'unregistered {client_id} from {pub_id}; {len(channel.clients)} left') + + if len(channel.clients) == 0: + registry = self._registry[graph_address] + del registry[pub_id] + + channel.close() + await channel.wait_closed() + + logger.debug(f'closed channel {pub_id}: no clients') + + +CHANNELS = ChannelManager() diff --git a/src/ezmsg/core/messagecache.py b/src/ezmsg/core/messagecache.py new file mode 100644 index 00000000..4ec1702b --- /dev/null +++ b/src/ezmsg/core/messagecache.py @@ -0,0 +1,124 @@ +import typing + +from .messagemarshal import MessageMarshal + + +class CacheMiss(Exception): ... + + +class CacheEntry(typing.NamedTuple): + object: typing.Any + msg_id: int + context: typing.ContextManager | None + memory: memoryview | None + + +class MessageCache: + + _cache: list[CacheEntry | None] + + def __init__(self, num_buffers: int) -> None: + self._cache = [None] * num_buffers + + def _buf_idx(self, msg_id: int) -> int: + return msg_id % len(self._cache) + + def __getitem__(self, msg_id: int) -> typing.Any: + """ + Get a cached object by msg_id + + :param msg_id: Message ID to retreive from cache + :type msg_id: int + :raises CacheMiss: If this msg_id does not exist in the cache. + """ + entry = self._cache[self._buf_idx(msg_id)] + if entry is None or entry.msg_id != msg_id: + raise CacheMiss + return entry.object + + def keys(self) -> list[int]: + """ + Get a list of current cached msg_ids + """ + return [entry.msg_id for entry in self._cache if entry is not None] + + def put_local(self, obj: typing.Any, msg_id: int) -> None: + """ + Put an object with associated msg_id directly into cache + + :param obj: Object to put in cache. + :type obj: typing.Any + :param msg_id: ID associated with this message/object. + :type msg_id: int + """ + self._put( + CacheEntry( + object = obj, + msg_id = msg_id, + context = None, + memory = None, + ) + ) + + def put_from_mem(self, mem: memoryview) -> None: + """ + Reconstitute a message in mem and keep it in cache, releasing and + overwriting the existing slot in cache. + This method passes the lifecycle of the memoryview to the MessageCache + and the memoryview will be properly released by the cache with `free` + + :param mem: Source memoryview containing serialized object. + :type from_mem: memoryview + :raises UninitializedMemory: If mem buffer is not properly initialized. + """ + ctx = MessageMarshal.obj_from_mem(mem) + self._put( + CacheEntry( + object = ctx.__enter__(), + msg_id = MessageMarshal.msg_id(mem), + context = ctx, + memory = mem, + ) + ) + + def _put(self, entry: CacheEntry) -> None: + buf_idx = self._buf_idx(entry.msg_id) + self._release(buf_idx) + self._cache[buf_idx] = entry + + def _release(self, buf_idx: int) -> None: + entry = self._cache[buf_idx] + if entry is not None: + mem = entry.memory + ctx = entry.context + if ctx is not None: + ctx.__exit__(None, None, None) + del entry + self._cache[buf_idx] = None + if mem is not None: + mem.release() + + def release(self, msg_id: int) -> None: + """ + Release memory for the entry associated with msg_id + + :param mem: Source memoryview containing serialized object. + :type from_mem: memoryview + :raises UninitializedMemory: If mem buffer is not properly initialized. + """ + buf_idx = self._buf_idx(msg_id) + entry = self._cache[buf_idx] + if entry is None or entry.msg_id != msg_id: + raise CacheMiss + self._release(buf_idx) + + def clear(self) -> None: + """ + Release all cached objects + + :param mem: Source memoryview containing serialized object. + :type from_mem: memoryview + :raises UninitializedMemory: If mem buffer is not properly initialized. + """ + for i in range(len(self._cache)): + self._release(i) diff --git a/src/ezmsg/core/messagechannel.py b/src/ezmsg/core/messagechannel.py index 35c664c5..fcd8f09c 100644 --- a/src/ezmsg/core/messagechannel.py +++ b/src/ezmsg/core/messagechannel.py @@ -9,7 +9,7 @@ from .shm import SHMContext from .messagemarshal import MessageMarshal from .backpressure import Backpressure - +from .messagecache import MessageCache from .graphserver import GraphService from .netprotocol import ( Command, @@ -20,7 +20,6 @@ uint64_to_bytes, encode_str, close_stream_writer, - GRAPHSERVER_ADDR ) logger = logging.getLogger("ezmsg") @@ -29,127 +28,6 @@ NotificationQueue = asyncio.Queue[typing.Tuple[UUID, int]] -class CacheMiss(Exception): ... - - -class CacheEntry(typing.NamedTuple): - object: typing.Any - msg_id: int - context: typing.ContextManager | None - memory: memoryview | None - - -class MessageCache: - - _cache: list[CacheEntry | None] - - def __init__(self, num_buffers: int) -> None: - self._cache = [None] * num_buffers - - def _buf_idx(self, msg_id: int) -> int: - return msg_id % len(self._cache) - - def __getitem__(self, msg_id: int) -> typing.Any: - """ - Get a cached object by msg_id - - :param msg_id: Message ID to retreive from cache - :type msg_id: int - :raises CacheMiss: If this msg_id does not exist in the cache. - """ - entry = self._cache[self._buf_idx(msg_id)] - if entry is None or entry.msg_id != msg_id: - raise CacheMiss - return entry.object - - def keys(self) -> list[int]: - """ - Get a list of current cached msg_ids - """ - return [entry.msg_id for entry in self._cache if entry is not None] - - def put_local(self, obj: typing.Any, msg_id: int) -> None: - """ - Put an object with associated msg_id directly into cache - - :param obj: Object to put in cache. - :type obj: typing.Any - :param msg_id: ID associated with this message/object. - :type msg_id: int - """ - self._put( - CacheEntry( - object = obj, - msg_id = msg_id, - context = None, - memory = None, - ) - ) - - def put_from_mem(self, mem: memoryview) -> None: - """ - Reconstitute a message in mem and keep it in cache, releasing and - overwriting the existing slot in cache. - This method passes the lifecycle of the memoryview to the MessageCache - and the memoryview will be properly released by the cache with `free` - - :param mem: Source memoryview containing serialized object. - :type from_mem: memoryview - :raises UninitializedMemory: If mem buffer is not properly initialized. - """ - ctx = MessageMarshal.obj_from_mem(mem) - self._put( - CacheEntry( - object = ctx.__enter__(), - msg_id = MessageMarshal.msg_id(mem), - context = ctx, - memory = mem, - ) - ) - - def _put(self, entry: CacheEntry) -> None: - buf_idx = self._buf_idx(entry.msg_id) - self._release(buf_idx) - self._cache[buf_idx] = entry - - def _release(self, buf_idx: int) -> None: - entry = self._cache[buf_idx] - if entry is not None: - mem = entry.memory - ctx = entry.context - if ctx is not None: - ctx.__exit__(None, None, None) - del entry - self._cache[buf_idx] = None - if mem is not None: - mem.release() - - def release(self, msg_id: int) -> None: - """ - Release memory for the entry associated with msg_id - - :param mem: Source memoryview containing serialized object. - :type from_mem: memoryview - :raises UninitializedMemory: If mem buffer is not properly initialized. - """ - buf_idx = self._buf_idx(msg_id) - entry = self._cache[buf_idx] - if entry is None or entry.msg_id != msg_id: - raise CacheMiss - self._release(buf_idx) - - def clear(self) -> None: - """ - Release all cached objects - - :param mem: Source memoryview containing serialized object. - :type from_mem: memoryview - :raises UninitializedMemory: If mem buffer is not properly initialized. - """ - for i in range(len(self._cache)): - self._release(i) - - class Channel: """cache-backed message channel for a particular publisher""" @@ -426,82 +304,3 @@ def unregister_client(self, client_id: UUID) -> None: self._local_backpressure = None del self.clients[client_id] - - -def _ensure_address(address: AddressType | None) -> Address: - if address is None: - return Address.from_string(GRAPHSERVER_ADDR) - - elif not isinstance(address, Address): - return Address(*address) - - return address - - -class ChannelManager: - - _registry: dict[Address, dict[UUID, Channel]] - - def __init__(self): - default_address = Address.from_string(GRAPHSERVER_ADDR) - self._registry = {default_address: dict()} - - async def register( - self, - pub_id: UUID, - client_id: UUID, - queue: NotificationQueue, - graph_address: AddressType | None = None, - ) -> Channel: - return await self._register(pub_id, client_id, queue, graph_address, None) - - async def register_local_pub( - self, - pub_id: UUID, - local_backpressure: Backpressure | None = None, - graph_address: AddressType | None = None, - ) -> Channel: - return await self._register(pub_id, pub_id, None, graph_address, local_backpressure) - - async def _register( - self, - pub_id: UUID, - client_id: UUID, - queue: NotificationQueue | None = None, - graph_address: AddressType | None = None, - local_backpressure: Backpressure | None = None - ) -> Channel: - graph_address = _ensure_address(graph_address) - try: - channel = self._registry.get(graph_address, dict())[pub_id] - except KeyError: - channel = await Channel.create(pub_id, graph_address) - channels = self._registry.get(graph_address, dict()) - channels[pub_id] = channel - self._registry[graph_address] = channels - channel.register_client(client_id, queue, local_backpressure) - return channel - - async def unregister( - self, - pub_id: UUID, - client_id: UUID, - graph_address: AddressType | None = None - ) -> None: - graph_address = _ensure_address(graph_address) - channel = self._registry.get(graph_address, dict())[pub_id] - channel.unregister_client(client_id) - - logger.debug(f'unregistered {client_id} from {pub_id}; {len(channel.clients)} left') - - if len(channel.clients) == 0: - registry = self._registry[graph_address] - del registry[pub_id] - - channel.close() - await channel.wait_closed() - - logger.debug(f'closed channel {pub_id}: no clients') - - -CHANNELS = ChannelManager() \ No newline at end of file diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index 4f57a3a5..45e8a721 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -5,12 +5,13 @@ from uuid import UUID from contextlib import suppress -from dataclasses import dataclass, field +from dataclasses import dataclass from .backpressure import Backpressure from .shm import SHMContext from .graphserver import GraphService -from .messagechannel import CHANNELS, Channel +from .channelmanager import CHANNELS +from .messagechannel import Channel from .messagemarshal import MessageMarshal, UninitializedMemory from .netprotocol import ( diff --git a/src/ezmsg/core/subclient.py b/src/ezmsg/core/subclient.py index f8b2fb5e..ad645d3d 100644 --- a/src/ezmsg/core/subclient.py +++ b/src/ezmsg/core/subclient.py @@ -1,5 +1,4 @@ import asyncio -from collections.abc import AsyncGenerator import logging import typing @@ -8,7 +7,8 @@ from copy import deepcopy from .graphserver import GraphService -from .messagechannel import CHANNELS, NotificationQueue, Channel +from .channelmanager import CHANNELS +from .messagechannel import NotificationQueue, Channel from .netprotocol import ( AddressType, From 592233fb8bf27a24b896506b10d88751a4cd1c00 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Wed, 19 Nov 2025 11:30:28 -0500 Subject: [PATCH 79/88] Added docstrings for Channel and ChannelManager --- src/ezmsg/core/channelmanager.py | 45 +++++++++++++++++++++++- src/ezmsg/core/messagechannel.py | 59 +++++++++++++++++++++++++++++--- 2 files changed, 99 insertions(+), 5 deletions(-) diff --git a/src/ezmsg/core/channelmanager.py b/src/ezmsg/core/channelmanager.py index 03815f69..d5f66e0a 100644 --- a/src/ezmsg/core/channelmanager.py +++ b/src/ezmsg/core/channelmanager.py @@ -23,6 +23,10 @@ def _ensure_address(address: AddressType | None) -> Address: class ChannelManager: + """ + ChannelManager maintains a process-specific registry of Channels, tracks what clients + need access to what channels, and creates/deallocates Channels accordingly + """ _registry: dict[Address, dict[UUID, Channel]] @@ -37,14 +41,42 @@ async def register( queue: NotificationQueue, graph_address: AddressType | None = None, ) -> Channel: + """ + Acquire the channel associated with a particular publisher, creating it if necessary + + :param pub_id: The UUID associated with the publisher to acquire a channel for + :type pub_id: UUID + :param client_id: The UUID associated with the client interested in this channel, used for deallocation when the channel has no more registered clients + :type client_id: UUID + :param queue: An asyncio Queue for the interested client that will be populated with incoming message notifications + :type queue: asyncio.Queue[tuple[UUID, int]] + :param graph_address: The address to the GraphServer that the requested publisher is managed by + :type graph_address: AddressType | None + :return: A Channel for retreiving messages from the requested Publisher + :rtype: Channel + """ return await self._register(pub_id, client_id, queue, graph_address, None) async def register_local_pub( self, pub_id: UUID, - local_backpressure: Backpressure | None = None, + local_backpressure: Backpressure, graph_address: AddressType | None = None, ) -> Channel: + """ + Register/create a channel for a publisher that is local to (in the same process as) the publisher in question. + Because this message channel is local to the publisher, it will directly push messages into the channel and the channel will directly manage the publisher's Backpressure without any telemetry or serialization. + .. note:: Since only a Publisher should be registering a channel this way, it will not need access to the messages it is publishing, hence it will not provide a queue. + + :param pub_id: The UUID associated with the publisher to acquire a channel for + :type pub_id: UUID + :param local_backpressure: The Backpressure object associated with the Publisher + :type local_backpressure: Backpressure + :param graph_address: The address to the GraphServer that the requested publisher is managed by + :type graph_address: AddressType | None + :return: A Channel that the Publisher can push messages to locally + :rtype: Channel + """ return await self._register(pub_id, pub_id, None, graph_address, local_backpressure) async def _register( @@ -72,6 +104,17 @@ async def unregister( client_id: UUID, graph_address: AddressType | None = None ) -> None: + """ + Indicate to the ChannelManager that the client referred to by client_id no longer needs access to the Channel associated with the publisher referred to by pub_id. + If no clients need access to this channel, the channel will be closed and removed from the ChannelManager. + + :param pub_id: The UUID associated with the publisher to acquire a channel for + :type pub_id: UUID + :param client_id: The UUID associated with the client interested in this channel, used for deallocation when the channel has no more registered clients + :type client_id: UUID + :param graph_address: The address to the GraphServer that the requested publisher is managed by + :type graph_address: AddressType | None + """ graph_address = _ensure_address(graph_address) channel = self._registry.get(graph_address, dict())[pub_id] channel.unregister_client(client_id) diff --git a/src/ezmsg/core/messagechannel.py b/src/ezmsg/core/messagechannel.py index fcd8f09c..28474a0f 100644 --- a/src/ezmsg/core/messagechannel.py +++ b/src/ezmsg/core/messagechannel.py @@ -29,7 +29,16 @@ class Channel: - """cache-backed message channel for a particular publisher""" + """ + Channel is a "middle-man" that receives messages from a particular Publisher, + maintains the message in a MessageCache, and pushes notifications to interested + Subscribers in this process. + + Channel primarily exists to reduce redundant message serialization and telemetry. + + .. note:: + The Channel constructor should not be called directly, instead use Channel.create(...) + """ id: UUID pub_id: UUID @@ -73,6 +82,18 @@ async def create( pub_id: UUID, graph_address: AddressType, ) -> "Channel": + """ + Create a channel for a particular Publisher managed by a GraphServer at graph_address + + :param pub_id: The Publisher's UUID on the GraphServer + :type pub_id: UUID + :param graph_address: The address the GraphServer is hosted on. + :type graph_address: AddressType + :return: a configured and connected Channel for messages from the Publisher + :rtype: Channel + + .. note:: This is typically called by ChannelManager as interested Subscribers register. + """ graph_service = GraphService(graph_address) graph_reader, graph_writer = await graph_service.open_connection() @@ -129,10 +150,16 @@ async def create( return chan def close(self) -> None: + """ + Mark the Channel for shutdown and resource deallocation + """ self._pub_task.cancel() self._graph_task.cancel() async def wait_closed(self) -> None: + """ + Wait until the Channel has properly shutdown and its resources have been deallocated. + """ with suppress(asyncio.CancelledError): await self._pub_task with suppress(asyncio.CancelledError): @@ -143,6 +170,9 @@ async def wait_closed(self) -> None: async def _graph_connection( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter ) -> None: + """ + The task that handles communication between the GraphServer and the Publisher. + """ try: while True: cmd = await reader.read(1) @@ -161,6 +191,9 @@ async def _graph_connection( await close_stream_writer(writer) async def _publisher_connection(self, reader: asyncio.StreamReader) -> None: + """ + The task that handles communication between the Channel and the Publisher it receives messages from. + """ try: while True: msg = await reader.read(1) @@ -233,8 +266,8 @@ def _notify_clients(self, msg_id: int) -> bool: def put_local(self, msg_id: int, msg: typing.Any) -> None: """ - put an object into cache (should only be used by Publishers) - returns true if any clients were notified + Put a message DIRECTLY into cache and notify all clients. + .. note:: This command should ONLY be used by Publishers that are in the same process as this Channel. """ if self._local_backpressure is None: raise ValueError('cannot put_local without access to publisher backpressure (is publisher in same process?)') @@ -247,13 +280,15 @@ def put_local(self, msg_id: int, msg: typing.Any) -> None: @contextmanager def get(self, msg_id: int, client_id: UUID) -> typing.Generator[typing.Any, None, None]: """ - Get a message + Get a message via a ContextManager :param msg_id: Message ID to retreive :type msg_id: int :param client_id: UUID of client retreiving this message for backpressure purposes :type client_id: UUID :raises CacheMiss: If this msg_id does not exist in the cache. + :return: A ContextManager for the message (type: Any) + :rtype: Generator[Any] """ try: @@ -283,11 +318,27 @@ def register_client( queue: NotificationQueue | None = None, local_backpressure: Backpressure | None = None, ) -> None: + """ + Register an interested client and provide a queue for incoming message notifications. + + :param client_id: The UUID of the subscribing client + :type client_id: UUID + :param queue: The notification queue for the subscribing client + :type queue: asyncio.Queue[tuple[UUID, int]] | None + :param local_backpressure: The backpressure object for the Publisher if it is in the same process + :type local_backpressure: Backpressure + """ self.clients[client_id] = queue if client_id == self.pub_id: self._local_backpressure = local_backpressure def unregister_client(self, client_id: UUID) -> None: + """ + Unregister a subscribed client + + :param client_id: The UUID of the subscribing client + :type client_id: UUID + """ queue = self.clients[client_id] # queue is only 'None' if this client is a local publisher From 5ca5711e7714042f23d1286abf946217d75318c2 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Thu, 20 Nov 2025 11:39:13 -0500 Subject: [PATCH 80/88] chore: ruff formatting --- docs/source/conf.py | 7 +- examples/ezmsg_configs.py | 2 +- examples/ezmsg_toy.py | 6 +- src/ezmsg/core/addressable.py | 15 +- src/ezmsg/core/backend.py | 22 ++- src/ezmsg/core/backendprocess.py | 40 +++-- src/ezmsg/core/backpressure.py | 36 ++-- src/ezmsg/core/channelmanager.py | 42 +++-- src/ezmsg/core/collection.py | 18 +- src/ezmsg/core/command.py | 7 +- src/ezmsg/core/component.py | 16 +- src/ezmsg/core/dag.py | 42 ++--- src/ezmsg/core/graph_util.py | 11 +- src/ezmsg/core/graphcontext.py | 27 ++- src/ezmsg/core/graphserver.py | 64 +++---- src/ezmsg/core/message.py | 10 +- src/ezmsg/core/messagecache.py | 33 ++-- src/ezmsg/core/messagechannel.py | 76 +++++---- src/ezmsg/core/messagemarshal.py | 40 ++--- src/ezmsg/core/netprotocol.py | 52 +++--- src/ezmsg/core/pubclient.py | 124 +++++++------- src/ezmsg/core/settings.py | 5 +- src/ezmsg/core/shm.py | 46 +++--- src/ezmsg/core/state.py | 5 +- src/ezmsg/core/stream.py | 6 +- src/ezmsg/core/subclient.py | 58 ++++--- src/ezmsg/core/test.py | 14 +- src/ezmsg/core/unit.py | 34 ++-- src/ezmsg/core/util.py | 5 +- src/ezmsg/util/__init__.py | 2 +- src/ezmsg/util/messagecodec.py | 28 ++-- src/ezmsg/util/messagegate.py | 4 +- src/ezmsg/util/messagelogger.py | 6 +- src/ezmsg/util/messages/axisarray.py | 130 +++++++-------- src/ezmsg/util/messages/chunker.py | 19 ++- src/ezmsg/util/messages/key.py | 27 +-- src/ezmsg/util/messages/modify.py | 17 +- src/ezmsg/util/messages/util.py | 4 +- src/ezmsg/util/perf/analysis.py | 185 +++++++++++++-------- src/ezmsg/util/perf/command.py | 6 +- src/ezmsg/util/perf/envinfo.py | 55 +++++-- src/ezmsg/util/perf/impl.py | 112 ++++++------- src/ezmsg/util/perf/run.py | 238 +++++++++++++++------------ src/ezmsg/util/perf/util.py | 37 ++++- src/ezmsg/util/rate.py | 6 +- tests/ez_test_utils.py | 5 +- tests/messages/test_axisarray.py | 12 +- tests/messages/test_key.py | 1 - tests/messages/test_modify.py | 1 - tests/test_graph.py | 15 +- tests/test_new_messaging.py | 61 +++---- tests/test_reconstitute.py | 84 ++++++---- 52 files changed, 1033 insertions(+), 885 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index d7a569fc..e6e8e389 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,7 +1,6 @@ # Configuration file for the Sphinx documentation builder. import os -import sys # -- Project information -------------------------- @@ -75,7 +74,7 @@ html_static_path = ["_static"] # Timestamp is inserted at every page bottom in this strftime format. -html_last_updated_fmt = '%Y-%m-%d' +html_last_updated_fmt = "%Y-%m-%d" # -- Options for EPUB output -------------------------- epub_show_urls = "footnote" @@ -95,8 +94,10 @@ def linkcode_resolve(domain, info): else: return f"{code_url}src/{filename}.py" + # -- Options for graphviz ----------------------------- graphviz_output_format = "svg" + def setup(app): - app.add_css_file("custom.css") \ No newline at end of file + app.add_css_file("custom.css") diff --git a/examples/ezmsg_configs.py b/examples/ezmsg_configs.py index d60934f2..113295eb 100644 --- a/examples/ezmsg_configs.py +++ b/examples/ezmsg_configs.py @@ -239,5 +239,5 @@ def network(self) -> ez.NetworkDefinition: ] for system in test_systems: - ez.logger.info(f"Testing { system.__name__ }") + ez.logger.info(f"Testing {system.__name__}") ez.run(system()) diff --git a/examples/ezmsg_toy.py b/examples/ezmsg_toy.py index c2ea3217..a5c5c772 100644 --- a/examples/ezmsg_toy.py +++ b/examples/ezmsg_toy.py @@ -192,8 +192,8 @@ def process_components(self): ez.run( SYSTEM=system, - connections = [ + connections=[ # Make PING.OUTPUT available on a topic ezmsg_attach.py - ( system.PING.OUTPUT, 'GLOBAL_PING_TOPIC' ), - ] + (system.PING.OUTPUT, "GLOBAL_PING_TOPIC"), + ], ) diff --git a/src/ezmsg/core/addressable.py b/src/ezmsg/core/addressable.py index e74d2731..9dd15f82 100644 --- a/src/ezmsg/core/addressable.py +++ b/src/ezmsg/core/addressable.py @@ -4,17 +4,18 @@ class Addressable: """ Base class for objects that can be addressed within the ezmsg system. - - Addressable objects have a hierarchical address structure consisting of + + Addressable objects have a hierarchical address structure consisting of a location path and a name, similar to a filesystem path. """ + _name: str | None _location: list[str] | None def __init__(self) -> None: """ Initialize an Addressable object. - + The name and location are initially None and must be set before the object can be properly addressed. This is achieved through the ``_set_name()`` and ``_set_location()`` methods. @@ -38,7 +39,7 @@ def _set_location(self, location: list[str] | None = None): def name(self) -> str: """ Get the name of this addressable object. - + :return: The object's name :rtype: str :raises AssertionError: If name has not been set @@ -51,7 +52,7 @@ def name(self) -> str: def location(self) -> list[str]: """ Get the location path of this addressable object. - + :return: List of path components representing the object's location :rtype: list[str] :raises AssertionError: If location has not been set @@ -64,10 +65,10 @@ def location(self) -> list[str]: def address(self) -> str: """ Get the full address of this object. - + The address is constructed by joining the location path and name with forward slashes, similar to a filesystem path. - + :return: The full address string :rtype: str """ diff --git a/src/ezmsg/core/backend.py b/src/ezmsg/core/backend.py index 90114756..09b35bd6 100644 --- a/src/ezmsg/core/backend.py +++ b/src/ezmsg/core/backend.py @@ -57,7 +57,6 @@ def create_processes( graph_address: AddressType | None, backend_process: type[BackendProcess] = DefaultBackendProcess, ) -> None: - self._processes = [ backend_process( process_units, @@ -145,7 +144,7 @@ def configure_collections(comp: Component): if not processes: return None - + return cls( processes, graph_connections, @@ -160,8 +159,8 @@ def run_system( ) -> None: """ Deprecated function for running a system (Collection). - - .. deprecated:: + + .. deprecated:: Use :func:`run` instead to run any component (unit, collection). :param system: The collection to run @@ -192,9 +191,9 @@ def run( This is the main entry point for running ezmsg applications. It sets up the execution environment, initializes components, and manages the message-passing - infrastructure. + infrastructure. - On initialization, ezmsg will call ``initialize()`` for each :obj:`Unit` and + On initialization, ezmsg will call ``initialize()`` for each :obj:`Unit` and ``configure()`` for each :obj:`Collection`, if defined. On initialization, ezmsg will create a directed acyclic graph using the contents of ``connections``. @@ -218,9 +217,9 @@ def run( :param components_kwargs: Additional components specified as keyword arguments :type components_kwargs: Component - .. note:: + .. note:: Since jupyter notebooks run in a single process, you must set `force_single_process=True`. - + .. note:: The old method :obj:`run_system` has been deprecated and uses ``run()`` instead. """ @@ -264,13 +263,12 @@ async def create_graph_context() -> GraphContext: address = graph_context.graph_address if address is None: address = GraphService.default_address() - logger.info(f'Connected to GraphServer @ {address}') + logger.info(f"Connected to GraphServer @ {address}") else: - logger.info(f'Spawned GraphServer @ {graph_context.graph_address}') + logger.info(f"Spawned GraphServer @ {graph_context.graph_address}") execution_context.create_processes( - graph_address=graph_context.graph_address, - backend_process=backend_process + graph_address=graph_context.graph_address, backend_process=backend_process ) async def cleanup_graph() -> None: diff --git a/src/ezmsg/core/backendprocess.py b/src/ezmsg/core/backendprocess.py index 1eb67b16..ba93ead5 100644 --- a/src/ezmsg/core/backendprocess.py +++ b/src/ezmsg/core/backendprocess.py @@ -22,7 +22,6 @@ from .unit import Unit, TIMEIT_ATTR, SUBSCRIBES_ATTR, ZERO_COPY_ATTR from .graphcontext import GraphContext -from .graphserver import GraphService from .pubclient import Publisher from .subclient import Subscriber from .netprotocol import AddressType @@ -34,7 +33,7 @@ class Complete(Exception): """ A type of Exception raised by Unit methods, which signals to ezmsg that the function can be shut down gracefully. - + If all functions in all Units raise Complete, the entire pipeline will terminate execution. This exception is used to signal normal completion of processing tasks. @@ -49,7 +48,7 @@ class Complete(Exception): class NormalTermination(Exception): """ A type of Exception which signals to ezmsg that the pipeline can be shut down gracefully. - + This exception is used to indicate that the system should terminate normally, typically when all processing is complete or when a graceful shutdown is requested. @@ -63,11 +62,12 @@ class NormalTermination(Exception): class BackendProcess(Process): """ Abstract base class for backend processes that execute Units. - + BackendProcess manages the execution of Units in a separate process, handling initialization, coordination with other processes via barriers, and cleanup operations. """ + units: list[Unit] term_ev: EventType start_barrier: BarrierType @@ -84,7 +84,7 @@ def __init__( ) -> None: """ Initialize the backend process. - + :param units: List of Units to execute in this process. :type units: list[Unit] :param term_ev: Event for coordinated termination. @@ -106,11 +106,10 @@ def __init__( self.graph_address = graph_address self.task_finished_ev: threading.Event | None = None - def run(self) -> None: """ Main entry point for the process execution. - + Sets up the event loop and handles the main processing logic with proper exception handling for interrupts. """ @@ -125,10 +124,10 @@ def run(self) -> None: def process(self, loop: asyncio.AbstractEventLoop) -> None: """ Abstract method for implementing the main processing logic. - + Subclasses must implement this method to define how Units are executed within the event loop. - + :param loop: The asyncio event loop for this process. :type loop: asyncio.AbstractEventLoop :raises NotImplementedError: Must be implemented by subclasses. @@ -139,11 +138,12 @@ def process(self, loop: asyncio.AbstractEventLoop) -> None: class DefaultBackendProcess(BackendProcess): """ Default implementation of BackendProcess for executing Units. - + This class provides the standard execution model for ezmsg Units, handling publishers, subscribers, and the complete Unit lifecycle including initialization, execution, and shutdown. """ + pubs: dict[str, Publisher] def process(self, loop: asyncio.AbstractEventLoop) -> None: @@ -230,11 +230,9 @@ async def setup_state(): logger.debug("Waiting at start barrier!") self.start_barrier.wait() - threads = [ - loop.run_in_executor(None, thread_fn, unit) - for unit in self.units - for thread_fn in unit.threads.values() - ] + for unit in self.units: + for thread_fn in unit.threads.values(): + loop.run_in_executor(None, thread_fn, unit) for pub in self.pubs.values(): pub.resume() @@ -398,11 +396,11 @@ async def handle_subscriber( ): """ Handle incoming messages from a subscriber and distribute to callables. - + Continuously receives messages from the subscriber and calls all registered callables with each message. Removes callables that raise Complete or NormalTermination exceptions. - + :param sub: Subscriber to receive messages from. :type sub: Subscriber :param callables: Set of async callables to invoke with messages. @@ -430,10 +428,10 @@ async def handle_subscriber( def run_loop(loop: asyncio.AbstractEventLoop): """ Run an asyncio event loop in the current thread. - + Sets the event loop for the current thread and runs it forever until interrupted or stopped. - + :param loop: The asyncio event loop to run. :type loop: asyncio.AbstractEventLoop """ @@ -450,10 +448,10 @@ def new_threaded_event_loop( ) -> Generator[asyncio.AbstractEventLoop, None, None]: """ Create a new asyncio event loop running in a separate thread. - + Provides a context manager that yields an event loop running in its own thread, allowing async operations to be run from synchronous code. - + :param ev: Optional event to signal when the loop is ready. :type ev: threading.Event | None :return: Context manager yielding the event loop. diff --git a/src/ezmsg/core/backpressure.py b/src/ezmsg/core/backpressure.py index 6a91759e..56c0e7cf 100644 --- a/src/ezmsg/core/backpressure.py +++ b/src/ezmsg/core/backpressure.py @@ -8,17 +8,18 @@ class BufferLease: """ Manages leases for a single buffer in the backpressure system. - + A BufferLease tracks which clients have active leases on a buffer and provides synchronization when the buffer becomes empty. """ + leases: set[UUID] empty: asyncio.Event def __init__(self) -> None: """ Initialize a new BufferLease. - + The buffer starts empty with no active leases. """ self.leases = set() @@ -28,7 +29,7 @@ def __init__(self) -> None: def add(self, uuid: UUID) -> None: """ Add a lease for the specified client. - + :param uuid: Unique identifier for the client :type uuid: UUID """ @@ -38,7 +39,7 @@ def add(self, uuid: UUID) -> None: def remove(self, uuid: UUID) -> None: """ Remove a lease for the specified client. - + :param uuid: Unique identifier for the client :type uuid: UUID """ @@ -49,7 +50,7 @@ def remove(self, uuid: UUID) -> None: async def wait(self) -> Literal[True]: """ Wait until the buffer becomes empty (no active leases). - + :return: Always returns True when the buffer is empty :rtype: Literal[True] """ @@ -59,7 +60,7 @@ async def wait(self) -> Literal[True]: def is_empty(self) -> bool: """ Check if the buffer has no active leases. - + :return: True if no leases are active, False otherwise :rtype: bool """ @@ -69,14 +70,15 @@ def is_empty(self) -> bool: class Backpressure: """ Manages backpressure for multiple buffers in a message passing system. - + The Backpressure class coordinates access to multiple buffers, tracking which buffers are in use and providing synchronization mechanisms to manage flow control between publishers and subscribers. - + :param num_buffers: Number of buffers to manage :type num_buffers: int """ + buffers: list[BufferLease] empty: asyncio.Event pressure: int @@ -84,7 +86,7 @@ class Backpressure: def __init__(self, num_buffers: int) -> None: """ Initialize backpressure management for the specified number of buffers. - + :param num_buffers: Number of buffers to create and manage :type num_buffers: int """ @@ -97,7 +99,7 @@ def __init__(self, num_buffers: int) -> None: def is_empty(self) -> bool: """ Check if all buffers are empty (no active pressure). - + :return: True if no buffers have active leases, False otherwise :rtype: bool """ @@ -106,7 +108,7 @@ def is_empty(self) -> bool: def available(self, buf_idx: int) -> bool: """ Check if a specific buffer is available (has no active leases). - + :param buf_idx: Index of the buffer to check :type buf_idx: int :return: True if the buffer is available, False otherwise @@ -117,7 +119,7 @@ def available(self, buf_idx: int) -> bool: async def wait(self, buf_idx: int) -> None: """ Wait until a specific buffer becomes available. - + :param buf_idx: Index of the buffer to wait for :type buf_idx: int """ @@ -126,7 +128,7 @@ async def wait(self, buf_idx: int) -> None: def lease(self, uuid: UUID, buf_idx: int) -> None: """ Create a lease on a specific buffer for the given client. - + :param uuid: Unique identifier for the client :type uuid: UUID :param buf_idx: Index of the buffer to lease @@ -140,7 +142,7 @@ def lease(self, uuid: UUID, buf_idx: int) -> None: def _free(self, uuid: UUID, buf_idx: int) -> None: """ Internal method to free a lease on a specific buffer. - + :param uuid: Unique identifier for the client :type uuid: UUID :param buf_idx: Index of the buffer to free @@ -156,10 +158,10 @@ def _free(self, uuid: UUID, buf_idx: int) -> None: def free(self, uuid: UUID, buf_idx: int | None = None) -> None: """ Free leases for the given client. - + If buf_idx is specified, only frees the lease on that buffer. If buf_idx is None, frees leases on all buffers for this client. - + :param uuid: Unique identifier for the client :type uuid: UUID :param buf_idx: Optional buffer index to free, or None to free all @@ -177,7 +179,7 @@ def free(self, uuid: UUID, buf_idx: int | None = None) -> None: async def sync(self) -> Literal[True]: """ Wait until all buffers are empty (no backpressure). - + :return: Always returns True when all buffers are empty :rtype: Literal[True] """ diff --git a/src/ezmsg/core/channelmanager.py b/src/ezmsg/core/channelmanager.py index d5f66e0a..42231d06 100644 --- a/src/ezmsg/core/channelmanager.py +++ b/src/ezmsg/core/channelmanager.py @@ -4,21 +4,18 @@ from .messagechannel import Channel, NotificationQueue from .backpressure import Backpressure -from .netprotocol import ( - Address, - AddressType, - GRAPHSERVER_ADDR -) +from .netprotocol import Address, AddressType, GRAPHSERVER_ADDR logger = logging.getLogger("ezmsg") + def _ensure_address(address: AddressType | None) -> Address: if address is None: return Address.from_string(GRAPHSERVER_ADDR) elif not isinstance(address, Address): return Address(*address) - + return address @@ -27,7 +24,7 @@ class ChannelManager: ChannelManager maintains a process-specific registry of Channels, tracks what clients need access to what channels, and creates/deallocates Channels accordingly """ - + _registry: dict[Address, dict[UUID, Channel]] def __init__(self): @@ -43,7 +40,7 @@ async def register( ) -> Channel: """ Acquire the channel associated with a particular publisher, creating it if necessary - + :param pub_id: The UUID associated with the publisher to acquire a channel for :type pub_id: UUID :param client_id: The UUID associated with the client interested in this channel, used for deallocation when the channel has no more registered clients @@ -67,7 +64,7 @@ async def register_local_pub( Register/create a channel for a publisher that is local to (in the same process as) the publisher in question. Because this message channel is local to the publisher, it will directly push messages into the channel and the channel will directly manage the publisher's Backpressure without any telemetry or serialization. .. note:: Since only a Publisher should be registering a channel this way, it will not need access to the messages it is publishing, hence it will not provide a queue. - + :param pub_id: The UUID associated with the publisher to acquire a channel for :type pub_id: UUID :param local_backpressure: The Backpressure object associated with the Publisher @@ -77,15 +74,17 @@ async def register_local_pub( :return: A Channel that the Publisher can push messages to locally :rtype: Channel """ - return await self._register(pub_id, pub_id, None, graph_address, local_backpressure) + return await self._register( + pub_id, pub_id, None, graph_address, local_backpressure + ) async def _register( - self, - pub_id: UUID, - client_id: UUID, - queue: NotificationQueue | None = None, + self, + pub_id: UUID, + client_id: UUID, + queue: NotificationQueue | None = None, graph_address: AddressType | None = None, - local_backpressure: Backpressure | None = None + local_backpressure: Backpressure | None = None, ) -> Channel: graph_address = _ensure_address(graph_address) try: @@ -99,15 +98,12 @@ async def _register( return channel async def unregister( - self, - pub_id: UUID, - client_id: UUID, - graph_address: AddressType | None = None + self, pub_id: UUID, client_id: UUID, graph_address: AddressType | None = None ) -> None: """ Indicate to the ChannelManager that the client referred to by client_id no longer needs access to the Channel associated with the publisher referred to by pub_id. If no clients need access to this channel, the channel will be closed and removed from the ChannelManager. - + :param pub_id: The UUID associated with the publisher to acquire a channel for :type pub_id: UUID :param client_id: The UUID associated with the client interested in this channel, used for deallocation when the channel has no more registered clients @@ -119,7 +115,9 @@ async def unregister( channel = self._registry.get(graph_address, dict())[pub_id] channel.unregister_client(client_id) - logger.debug(f'unregistered {client_id} from {pub_id}; {len(channel.clients)} left') + logger.debug( + f"unregistered {client_id} from {pub_id}; {len(channel.clients)} left" + ) if len(channel.clients) == 0: registry = self._registry[graph_address] @@ -128,7 +126,7 @@ async def unregister( channel.close() await channel.wait_closed() - logger.debug(f'closed channel {pub_id}: no clients') + logger.debug(f"closed channel {pub_id}: no clients") CHANNELS = ChannelManager() diff --git a/src/ezmsg/core/collection.py b/src/ezmsg/core/collection.py index 4ac6e8f5..dd8f665d 100644 --- a/src/ezmsg/core/collection.py +++ b/src/ezmsg/core/collection.py @@ -9,9 +9,7 @@ # Iterable of (output_stream, input_stream) pairs defining the network connections -NetworkDefinition = Iterable[ - tuple[Stream | str, Stream | str] -] +NetworkDefinition = Iterable[tuple[Stream | str, Stream | str]] class CollectionMeta(ComponentMeta): @@ -41,7 +39,7 @@ def __init__( class Collection(Component, metaclass=CollectionMeta): """ Connects :obj:`Units ` together by defining a graph which connects OutputStreams to InputStreams. - + Collections are composite components that contain and coordinate multiple Units, defining how they communicate through stream connections. @@ -59,8 +57,8 @@ def __init__(self, *args, settings: Settings | None = None, **kwargs): def configure(self) -> None: """ A lifecycle hook that runs when the Collection is instantiated. - - This is the best place to call ``Unit.apply_settings()`` on each member + + This is the best place to call ``Unit.apply_settings()`` on each member Unit of the Collection. Override this method to perform collection-specific configuration of child components. """ @@ -68,9 +66,9 @@ def configure(self) -> None: def network(self) -> NetworkDefinition: """ - Override this method and have the definition return a NetworkDefinition + Override this method and have the definition return a NetworkDefinition which defines how InputStreams and OutputStreams from member Units will be connected. - + The NetworkDefinition specifies the message routing between components by connecting output streams to input streams. @@ -81,10 +79,10 @@ def network(self) -> NetworkDefinition: def process_components(self) -> AbstractCollection[Component]: """ - Override this method and have the definition return a tuple which contains + Override this method and have the definition return a tuple which contains Units and Collections which should run in their own processes. - This method allows you to specify which components should be isolated + This method allows you to specify which components should be isolated in separate processes for performance or isolation requirements. :return: Collection of components that should run in separate processes diff --git a/src/ezmsg/core/command.py b/src/ezmsg/core/command.py index 6f9b766b..09ce2dc3 100644 --- a/src/ezmsg/core/command.py +++ b/src/ezmsg/core/command.py @@ -5,7 +5,6 @@ import logging import subprocess import sys -import typing import webbrowser import zlib @@ -25,7 +24,7 @@ def cmdline() -> None: """ Command-line interface for ezmsg core server management. - + Provides commands for starting, stopping, and managing ezmsg server processes including GraphServer and SHMServer, as well as utilities for graph visualization. @@ -107,7 +106,7 @@ async def run_command( ) -> None: """ Run an ezmsg command with the specified parameters. - + This function handles various ezmsg commands like 'serve', 'start', 'shutdown', etc. and manages the graph and shared memory services. @@ -184,7 +183,7 @@ async def run_command( def mm(graph: str, target="live") -> str: """ Generate a Mermaid visualization URL for the given graph. - + :param graph: Graph representation string to visualize. :type graph: str :param target: Target platform ('live' or 'ink'). diff --git a/src/ezmsg/core/component.py b/src/ezmsg/core/component.py index d9e42902..fd831b93 100644 --- a/src/ezmsg/core/component.py +++ b/src/ezmsg/core/component.py @@ -98,16 +98,16 @@ def __init__( class Component(Addressable, metaclass=ComponentMeta): """ Metaclass which :obj:`Unit` and :obj:`Collection` inherit from. - + The Component class provides the foundation for all components in the ezmsg framework, including Units and Collections. It manages settings, state, streams, and provides the basic infrastructure for message-passing components. - + :param settings: Optional settings object for component configuration :type settings: Settings | None .. note:: - + When creating ezmsg nodes, inherit directly from :obj:`Unit` or :obj:`Collection`. """ @@ -193,7 +193,7 @@ def _set_location(self, location: list[str] | None = None): def tasks(self) -> dict[str, Callable]: """ Get the dictionary of tasks for this component. - + :return: Dictionary mapping task names to their callable functions :rtype: dict[str, collections.abc.Callable] """ @@ -203,7 +203,7 @@ def tasks(self) -> dict[str, Callable]: def streams(self) -> dict[str, Stream]: """ Get the dictionary of streams for this component. - + :return: Dictionary mapping stream names to their Stream objects :rtype: dict[str, Stream] """ @@ -213,7 +213,7 @@ def streams(self) -> dict[str, Stream]: def components(self) -> dict[str, "Component"]: """ Get the dictionary of child components for this component. - + :return: Dictionary mapping component names to their Component objects :rtype: dict[str, Component] """ @@ -223,7 +223,7 @@ def components(self) -> dict[str, "Component"]: def main(self) -> Callable[..., None] | None: """ Get the main function for this component. - + :return: The main callable function, or None if not set :rtype: collections.abc.Callable[..., None] | None """ @@ -233,7 +233,7 @@ def main(self) -> Callable[..., None] | None: def threads(self) -> dict[str, Callable]: """ Get the dictionary of thread functions for this component. - + :return: Dictionary mapping thread names to their callable functions :rtype: dict[str, collections.abc.Callable] """ diff --git a/src/ezmsg/core/dag.py b/src/ezmsg/core/dag.py index 285958c4..d0c9a72a 100644 --- a/src/ezmsg/core/dag.py +++ b/src/ezmsg/core/dag.py @@ -3,13 +3,14 @@ from dataclasses import dataclass, field -class CyclicException(Exception): +class CyclicException(Exception): """ Exception raised when an operation would create a cycle in the DAG. - + This exception is raised when attempting to add an edge that would violate the acyclic property of the directed acyclic graph. """ + ... @@ -20,18 +21,19 @@ class CyclicException(Exception): class DAG: """ Directed Acyclic Graph implementation for managing dependencies and connections. - + The DAG class provides functionality to build and maintain a directed acyclic graph, which is used by ezmsg to manage message flow between components while ensuring no circular dependencies exist. """ + graph: GraphType = field(default_factory=lambda: defaultdict(set), init=False) @property def nodes(self) -> set[str]: """ Get all nodes in the graph. - + :return: Set of all node names in the graph :rtype: set[str] """ @@ -41,13 +43,13 @@ def nodes(self) -> set[str]: def invgraph(self) -> GraphType: """ Get the inverse (reversed) graph. - + Creates a graph where all edges are reversed. This is useful for finding upstream dependencies. - + :return: Inverted graph representation :rtype: GraphType - + .. note:: This is currently implemented inefficiently but is adequate for typical use cases. """ @@ -61,13 +63,13 @@ def invgraph(self) -> GraphType: def add_edge(self, from_node: str, to_node: str) -> None: """ Ensure an edge exists in the graph. - + Adds an edge from from_node to to_node. Does nothing if the edge already exists. If the edge would make the graph cyclic, raises CyclicException. - + :param from_node: Source node name :type from_node: str - :param to_node: Destination node name + :param to_node: Destination node name :type to_node: str :raises CyclicException: If adding the edge would create a cycle """ @@ -88,10 +90,10 @@ def add_edge(self, from_node: str, to_node: str) -> None: def remove_edge(self, from_node: str, to_node: str) -> None: """ Ensure an edge is not present in the graph. - + Removes an edge from from_node to to_node. Does nothing if the edge doesn't exist. Automatically prunes unconnected nodes after removal. - + :param from_node: Source node name :type from_node: str :param to_node: Destination node name @@ -103,9 +105,9 @@ def remove_edge(self, from_node: str, to_node: str) -> None: def downstream(self, from_node: str) -> list[str]: """ Get a list of downstream nodes (including from_node). - + Performs a breadth-first search to find all nodes reachable from the given node. - + :param from_node: Starting node name :type from_node: str :return: List of downstream node names including the starting node @@ -116,10 +118,10 @@ def downstream(self, from_node: str) -> list[str]: def upstream(self, from_node: str) -> list[str]: """ Get a list of upstream nodes (including from_node). - + Performs a breadth-first search on the inverted graph to find all nodes that can reach the given node. - + :param from_node: Starting node name :type from_node: str :return: List of upstream node names including the starting node @@ -139,9 +141,9 @@ def _prune(self) -> None: def _leaves(graph: GraphType) -> set[str]: """ Find leaf nodes in a graph. - + Leaf nodes are nodes that have no outgoing edges. - + :param graph: The graph to analyze :type graph: GraphType :return: Set of leaf node names @@ -153,9 +155,9 @@ def _leaves(graph: GraphType) -> set[str]: def _bfs(graph: GraphType, node: str) -> list[str]: """ Breadth-first search of a graph starting from a given node. - + Traverses the graph in breadth-first order to find all reachable nodes. - + :param graph: The graph to search :type graph: GraphType :param node: Starting node for the search diff --git a/src/ezmsg/core/graph_util.py b/src/ezmsg/core/graph_util.py index 832b44ef..f6807bc5 100644 --- a/src/ezmsg/core/graph_util.py +++ b/src/ezmsg/core/graph_util.py @@ -1,6 +1,5 @@ from collections import defaultdict from textwrap import indent -from collections import defaultdict from uuid import uuid4 @@ -13,11 +12,11 @@ def prune_graph_connections( ) -> tuple[GraphType | None, list[str] | None]: """ Remove nodes from the graph that are proxy_topics. - + Proxy topics are nodes that are both source and target nodes in the connections graph. This function removes them and connects their upstream sources directly to their downstream targets. - + :param graph_connections: Graph representing topic connections. :type graph_connections: GraphType :return: Tuple of (pruned_graph, proxy_topics_list). @@ -53,7 +52,7 @@ def _pipeline_levels( ) -> defaultdict: """ Compute hierarchy levels for pipeline components. - + In ezmsg, a pipeline is built with units/collections, with subcomponents that are either more units/collection or input/output streams. The graph of the connections are stored in a DAG (directed acyclic graph object), but the @@ -63,7 +62,7 @@ def _pipeline_levels( This function computes the level of each component in the hierarchy (not just the connection nodes) and returns a dictionary of the level of each pipeline parts, where 0 is the level of the connection (leaf) nodes. - + :param graph_connections: Graph representing component connections. :type graph_connections: GraphType :param level_separator: Character separating hierarchy levels. @@ -99,7 +98,7 @@ def _get_parent_node(node: str, level_separator: str = "/") -> str: class LeafNodeException(Exception): """ Raised when connection nodes are not leaf nodes in the pipeline hierarchy. - + This exception indicates that the graph contains connections at non-leaf levels of the component hierarchy, which violates expected structure. """ diff --git a/src/ezmsg/core/graphcontext.py b/src/ezmsg/core/graphcontext.py index 03ee5763..544f4663 100644 --- a/src/ezmsg/core/graphcontext.py +++ b/src/ezmsg/core/graphcontext.py @@ -2,7 +2,7 @@ import logging import typing -from .netprotocol import AddressType, Address +from .netprotocol import AddressType from .graphserver import GraphServer, GraphService from .pubclient import Publisher from .subclient import Subscriber @@ -15,12 +15,12 @@ class GraphContext: """ GraphContext maintains a list of created publishers, subscribers, and connections in the graph. - + The GraphContext provides a managed environment for creating and tracking publishers, subscribers, and graph connections. When the context is no longer needed, it can revert changes in the graph which disconnects publishers and removes modifications that this context made. - + It also maintains a context manager that ensures the GraphServer is running. :param graph_service: Optional graph service instance to use @@ -40,7 +40,6 @@ class GraphContext: def __init__( self, graph_address: AddressType | None = None, - ) -> None: self._clients = set() self._edges = set() @@ -57,16 +56,14 @@ def graph_address(self) -> AddressType | None: async def publisher(self, topic: str, **kwargs) -> Publisher: """ Create a publisher for the specified topic. - + :param topic: The topic name to publish to :type topic: str :param kwargs: Additional keyword arguments for publisher configuration :return: A Publisher instance for the topic :rtype: Publisher """ - pub = await Publisher.create( - topic, self.graph_address, **kwargs - ) + pub = await Publisher.create(topic, self.graph_address, **kwargs) self._clients.add(pub) return pub @@ -74,16 +71,14 @@ async def publisher(self, topic: str, **kwargs) -> Publisher: async def subscriber(self, topic: str, **kwargs) -> Subscriber: """ Create a subscriber for the specified topic. - + :param topic: The topic name to subscribe to :type topic: str :param kwargs: Additional keyword arguments for subscriber configuration :return: A Subscriber instance for the topic :rtype: Subscriber """ - sub = await Subscriber.create( - topic, self.graph_address, **kwargs - ) + sub = await Subscriber.create(topic, self.graph_address, **kwargs) self._clients.add(sub) return sub @@ -91,7 +86,7 @@ async def subscriber(self, topic: str, **kwargs) -> Subscriber: async def connect(self, from_topic: str, to_topic: str) -> None: """ Connect two topics in the message graph. - + :param from_topic: The source topic name :type from_topic: str :param to_topic: The destination topic name @@ -104,7 +99,7 @@ async def connect(self, from_topic: str, to_topic: str) -> None: async def disconnect(self, from_topic: str, to_topic: str) -> None: """ Disconnect two topics in the message graph. - + :param from_topic: The source topic name :type from_topic: str :param to_topic: The destination topic name @@ -116,7 +111,7 @@ async def disconnect(self, from_topic: str, to_topic: str) -> None: async def sync(self, timeout: typing.Optional[float] = None) -> None: """ Synchronize with the graph server. - + :param timeout: Optional timeout for the sync operation :type timeout: float | None """ @@ -159,7 +154,7 @@ async def __aexit__( async def revert(self) -> None: """ Revert all changes made by this context. - + This method closes all publishers and subscribers created by this context and removes all edges that were added to the graph. It is automatically called when exiting the context manager. diff --git a/src/ezmsg/core/graphserver.py b/src/ezmsg/core/graphserver.py index c6d9226f..cb97b9ec 100644 --- a/src/ezmsg/core/graphserver.py +++ b/src/ezmsg/core/graphserver.py @@ -40,7 +40,7 @@ class GraphServer(threading.Thread): """ Pub-sub directed acyclic graph (DAG) server. - + The GraphServer manages the message routing graph for ezmsg applications, handling publisher-subscriber relationships and maintaining the DAG structure. @@ -51,6 +51,7 @@ class GraphServer(threading.Thread): The GraphServer is typically managed automatically by the ezmsg runtime and doesn't need to be instantiated directly by user code. """ + _server_up: threading.Event _shutdown: threading.Event @@ -65,7 +66,9 @@ class GraphServer(threading.Thread): _command_lock: asyncio.Lock def __init__(self, **kwargs) -> None: - super().__init__(**{**kwargs, **dict(daemon=True, name=kwargs.get("name", "GraphServer"))}) + super().__init__( + **{**kwargs, **dict(daemon=True, name=kwargs.get("name", "GraphServer"))} + ) # threading events for lifecycle self._server_up = threading.Event() self._shutdown = threading.Event() @@ -84,7 +87,9 @@ def start(self, address: AddressType | None = None) -> None: # type: ignore[ove if address is not None: self._sock = create_socket(*address) else: - start_port = int(os.environ.get(SERVER_PORT_START_ENV, SERVER_PORT_START_DEFAULT)) + start_port = int( + os.environ.get(SERVER_PORT_START_ENV, SERVER_PORT_START_DEFAULT) + ) self._sock = create_socket(start_port=start_port) self._loop = asyncio.new_event_loop() @@ -159,7 +164,7 @@ async def api( await writer.drain() req = await reader.read(1) - + if not req: # Empty bytes object means EOF; Client disconnected # This happens frequently when future clients are just pinging @@ -179,7 +184,7 @@ async def api( num_buffers = await read_int(reader) buf_size = await read_int(reader) - # Create segment + # Create segment shm_info = SHMInfo.create(num_buffers, buf_size) self.shms[shm_info.shm.name] = shm_info logger.debug(f"created {shm_info.shm.name}") @@ -197,7 +202,7 @@ async def api( # NOTE: SHMContexts are like GraphClients in that their # lifetime is bound to the lifetime of this connection - # but rather than the early return that we have with + # but rather than the early return that we have with # GraphClients, we just await the lease task. # This may be a more readable pattern? # NOTE: With the current shutdown pattern, we cancel @@ -218,20 +223,22 @@ async def api( # Advertise an address the channel should be able to resolve port = pub_info.address.port iface = pub_info.writer.transport.get_extra_info("peername")[0] - pub_addr = Address(iface, port) + pub_addr = Address(iface, port) else: logger.warning(f"Connecting channel requested {type(pub_info)}") except KeyError: - logger.warning(f"Connecting channel requested non-existent publisher {pub_id=}") + logger.warning( + f"Connecting channel requested non-existent publisher {pub_id=}" + ) if pub_addr is None: # FIXME: Channel should not be created; but it feels like we should - # have a better communication protocol to tell the channel what the + # have a better communication protocol to tell the channel what the # error was and deliver a better experience from the client side. # for now, drop connection await close_stream_writer(writer) - return + return writer.write(Command.COMPLETE.value) writer.write(encode_str(str(channel_id))) @@ -243,7 +250,7 @@ async def api( self._handle_client(channel_id, reader, writer) ) - # NOTE: Created a client, must return early + # NOTE: Created a client, must return early # to avoid closing writer return @@ -251,8 +258,8 @@ async def api( # We only want to handle one command at a time async with self._command_lock: if req in [ - Command.SUBSCRIBE.value, - Command.PUBLISH.value, + Command.SUBSCRIBE.value, + Command.PUBLISH.value, ]: client_id = uuid1() topic = await read_str(reader) @@ -281,7 +288,7 @@ async def api( writer.write(Command.COMPLETE.value) - # NOTE: Created a client, must return early + # NOTE: Created a client, must return early # to avoid closing writer return @@ -338,26 +345,29 @@ async def api( logger.debug("GraphServer connection fail mid-command") # NOTE: This prevents code repetition for many graph server commands, but - # when we create GraphClients, their lifecycle is bound to the lifecycle of - # this connection. We do NOT want to close the stream writer if we have + # when we create GraphClients, their lifecycle is bound to the lifecycle of + # this connection. We do NOT want to close the stream writer if we have # created a GraphClient, which requires an early return. Perhaps a different # communication protocol could resolve this await close_stream_writer(writer) async def _handle_client( - self, client_id: UUID, reader: asyncio.StreamReader, writer: asyncio.StreamWriter + self, + client_id: UUID, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, ) -> None: - """ + """ The lifecycle of a graph client is tied to the lifecycle of this TCP connection. We always attempt to read from the reader, and if the connection is ever closed from the other side, reader.read(1) returns empty, we break the loop, and delete - the client from our list of connected clients. Because we're always reading + the client from our list of connected clients. Because we're always reading from the client within this task, we have to handle all client responses within - this task. - + this task. + NOTE: _notify_subscriber requires a COMPLETE once the subscriber has reconfigured its connections. ClientInfo.sync_writer gives us the mechanism by which to block - _notify_subscriber until the COMPLETE is received in this context. + _notify_subscriber until the COMPLETE is received in this context. """ logger.debug(f"Graph Server: Client connected: {client_id}") @@ -391,7 +401,7 @@ async def _notify_subscriber(self, sub: SubscriberInfo) -> None: notify_str = ",".join(pub_ids) writer.write(Command.UPDATE.value) writer.write(encode_str(notify_str)) - + except (ConnectionResetError, BrokenPipeError) as e: logger.debug(f"Failed to update Subscriber {sub.id}: {e}") @@ -404,11 +414,9 @@ def _subscribers(self) -> list[SubscriberInfo]: return [ info for info in self.clients.values() if isinstance(info, SubscriberInfo) ] - + def _channels(self) -> list[ChannelInfo]: - return [ - info for info in self.clients.values() if isinstance(info, ChannelInfo) - ] + return [info for info in self.clients.values() if isinstance(info, ChannelInfo)] def _upstream_pubs(self, topic: str) -> list[PublisherInfo]: """Given a topic, return a set of all publisher IDs upstream of that topic""" @@ -603,4 +611,4 @@ async def attach_shm(self, name: str) -> SHMContext: raise ValueError("Invalid SHM Name") shm_name = await read_str(reader) - return SHMContext.attach(shm_name, reader, writer) \ No newline at end of file + return SHMContext.attach(shm_name, reader, writer) diff --git a/src/ezmsg/core/message.py b/src/ezmsg/core/message.py index ca3225d0..fced71d6 100644 --- a/src/ezmsg/core/message.py +++ b/src/ezmsg/core/message.py @@ -25,14 +25,15 @@ def __new__( class Message(ABC, metaclass=MessageMeta): """ Deprecated base class for messages in the ezmsg framework. - + .. deprecated:: Message is deprecated. Use @dataclass decorators instead of inheriting - from ez.Message. For data arrays, use :obj:`ezmsg.util.messages.AxisArray`. - + from ez.Message. For data arrays, use :obj:`ezmsg.util.messages.AxisArray`. + .. note:: This class will issue a DeprecationWarning when instantiated. """ + def __init__(self): warnings.warn( "Message is deprecated. Replace ez.Message with @dataclass decorators", @@ -45,10 +46,9 @@ def __init__(self): class Flag: """ A message with no contents. - + Flag is used as a simple signal message that carries no data, typically used for synchronization or simple event notification. """ ... - diff --git a/src/ezmsg/core/messagecache.py b/src/ezmsg/core/messagecache.py index 4ec1702b..c56b0458 100644 --- a/src/ezmsg/core/messagecache.py +++ b/src/ezmsg/core/messagecache.py @@ -14,7 +14,6 @@ class CacheEntry(typing.NamedTuple): class MessageCache: - _cache: list[CacheEntry | None] def __init__(self, num_buffers: int) -> None: @@ -26,7 +25,7 @@ def _buf_idx(self, msg_id: int) -> int: def __getitem__(self, msg_id: int) -> typing.Any: """ Get a cached object by msg_id - + :param msg_id: Message ID to retreive from cache :type msg_id: int :raises CacheMiss: If this msg_id does not exist in the cache. @@ -35,17 +34,17 @@ def __getitem__(self, msg_id: int) -> typing.Any: if entry is None or entry.msg_id != msg_id: raise CacheMiss return entry.object - + def keys(self) -> list[int]: """ Get a list of current cached msg_ids """ return [entry.msg_id for entry in self._cache if entry is not None] - + def put_local(self, obj: typing.Any, msg_id: int) -> None: """ Put an object with associated msg_id directly into cache - + :param obj: Object to put in cache. :type obj: typing.Any :param msg_id: ID associated with this message/object. @@ -53,10 +52,10 @@ def put_local(self, obj: typing.Any, msg_id: int) -> None: """ self._put( CacheEntry( - object = obj, - msg_id = msg_id, - context = None, - memory = None, + object=obj, + msg_id=msg_id, + context=None, + memory=None, ) ) @@ -66,7 +65,7 @@ def put_from_mem(self, mem: memoryview) -> None: overwriting the existing slot in cache. This method passes the lifecycle of the memoryview to the MessageCache and the memoryview will be properly released by the cache with `free` - + :param mem: Source memoryview containing serialized object. :type from_mem: memoryview :raises UninitializedMemory: If mem buffer is not properly initialized. @@ -74,10 +73,10 @@ def put_from_mem(self, mem: memoryview) -> None: ctx = MessageMarshal.obj_from_mem(mem) self._put( CacheEntry( - object = ctx.__enter__(), - msg_id = MessageMarshal.msg_id(mem), - context = ctx, - memory = mem, + object=ctx.__enter__(), + msg_id=MessageMarshal.msg_id(mem), + context=ctx, + memory=mem, ) ) @@ -97,11 +96,11 @@ def _release(self, buf_idx: int) -> None: self._cache[buf_idx] = None if mem is not None: mem.release() - + def release(self, msg_id: int) -> None: """ Release memory for the entry associated with msg_id - + :param mem: Source memoryview containing serialized object. :type from_mem: memoryview :raises UninitializedMemory: If mem buffer is not properly initialized. @@ -115,7 +114,7 @@ def release(self, msg_id: int) -> None: def clear(self) -> None: """ Release all cached objects - + :param mem: Source memoryview containing serialized object. :type from_mem: memoryview :raises UninitializedMemory: If mem buffer is not properly initialized. diff --git a/src/ezmsg/core/messagechannel.py b/src/ezmsg/core/messagechannel.py index 28474a0f..15dd180d 100644 --- a/src/ezmsg/core/messagechannel.py +++ b/src/ezmsg/core/messagechannel.py @@ -14,9 +14,9 @@ from .netprotocol import ( Command, Address, - AddressType, - read_str, - read_int, + AddressType, + read_str, + read_int, uint64_to_bytes, encode_str, close_stream_writer, @@ -33,7 +33,7 @@ class Channel: Channel is a "middle-man" that receives messages from a particular Publisher, maintains the message in a MessageCache, and pushes notifications to interested Subscribers in this process. - + Channel primarily exists to reduce redundant message serialization and telemetry. .. note:: @@ -58,10 +58,10 @@ class Channel: _local_backpressure: Backpressure | None def __init__( - self, - id: UUID, - pub_id: UUID, - num_buffers: int, + self, + id: UUID, + pub_id: UUID, + num_buffers: int, shm: SHMContext | None, graph_address: AddressType | None, ) -> None: @@ -105,8 +105,8 @@ async def create( # FIXME: This will happen if the channel requested connection # to a non-existent (or non-publisher) UUID. Ideally GraphServer # would tell us what happened rather than drop connection - raise ValueError(f'failed to create channel {pub_id=}') - + raise ValueError(f"failed to create channel {pub_id=}") + id_str = await read_str(graph_reader) pub_address = await Address.from_stream(graph_reader) @@ -126,29 +126,29 @@ async def create( result = await reader.read(1) if result != Command.COMPLETE.value: - # NOTE: The only reason this would happen is if the + # NOTE: The only reason this would happen is if the # publisher's writer is closed due to a crash or shutdown - raise ValueError(f'failed to create channel {pub_id=}') - + raise ValueError(f"failed to create channel {pub_id=}") + num_buffers = await read_int(reader) - + chan = cls(UUID(id_str), pub_id, num_buffers, shm, graph_address) chan._graph_task = asyncio.create_task( chan._graph_connection(graph_reader, graph_writer), - name = f'chan-{chan.id}: _graph_connection' + name=f"chan-{chan.id}: _graph_connection", ) chan._pub_writer = writer chan._pub_task = asyncio.create_task( chan._publisher_connection(reader), - name = f'chan-{chan.id}: _publisher_connection' + name=f"chan-{chan.id}: _publisher_connection", ) - logger.debug(f'created channel {chan.id=} {pub_id=} {pub_address=}') + logger.debug(f"created channel {chan.id=} {pub_id=} {pub_address=}") return chan - + def close(self) -> None: """ Mark the Channel for shutdown and resource deallocation @@ -166,7 +166,7 @@ async def wait_closed(self) -> None: await self._graph_task if self.shm is not None: await self.shm.wait_closed() - + async def _graph_connection( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter ) -> None: @@ -176,7 +176,7 @@ async def _graph_connection( try: while True: cmd = await reader.read(1) - + if not cmd: break @@ -208,14 +208,15 @@ async def _publisher_connection(self, reader: asyncio.StreamReader) -> None: shm_name = await read_str(reader) if self.shm is not None and self.shm.name != shm_name: - shm_entries = self.cache.keys() self.cache.clear() self.shm.close() await self.shm.wait_closed() try: - self.shm = await GraphService(self._graph_address).attach_shm(shm_name) + self.shm = await GraphService( + self._graph_address + ).attach_shm(shm_name) except ValueError: logger.info( "Invalid SHM received from publisher; may be dead" @@ -224,7 +225,7 @@ async def _publisher_connection(self, reader: asyncio.StreamReader) -> None: for id in shm_entries: self.cache.put_from_mem(self.shm[id % self.num_buffers]) - + assert self.shm is not None assert MessageMarshal.msg_id(self.shm[buf_idx]) == msg_id self.cache.put_from_mem(self.shm[buf_idx]) @@ -256,10 +257,11 @@ async def _publisher_connection(self, reader: asyncio.StreamReader) -> None: logger.debug(f"disconnected: channel:{self.id} -> pub:{self.pub_id}") def _notify_clients(self, msg_id: int) -> bool: - """ notify interested clients and return true if any were notified """ + """notify interested clients and return true if any were notified""" buf_idx = msg_id % self.num_buffers for client_id, queue in self.clients.items(): - if queue is None: continue # queue is none if this is the pub + if queue is None: + continue # queue is none if this is the pub self.backpressure.lease(client_id, buf_idx) queue.put_nowait((self.pub_id, msg_id)) return not self.backpressure.available(buf_idx) @@ -270,18 +272,22 @@ def put_local(self, msg_id: int, msg: typing.Any) -> None: .. note:: This command should ONLY be used by Publishers that are in the same process as this Channel. """ if self._local_backpressure is None: - raise ValueError('cannot put_local without access to publisher backpressure (is publisher in same process?)') - + raise ValueError( + "cannot put_local without access to publisher backpressure (is publisher in same process?)" + ) + buf_idx = msg_id % self.num_buffers if self._notify_clients(msg_id): self.cache.put_local(msg, msg_id) self._local_backpressure.lease(self.id, buf_idx) @contextmanager - def get(self, msg_id: int, client_id: UUID) -> typing.Generator[typing.Any, None, None]: + def get( + self, msg_id: int, client_id: UUID + ) -> typing.Generator[typing.Any, None, None]: """ Get a message via a ContextManager - + :param msg_id: Message ID to retreive :type msg_id: int :param client_id: UUID of client retreiving this message for backpressure purposes @@ -290,7 +296,7 @@ def get(self, msg_id: int, client_id: UUID) -> typing.Generator[typing.Any, None :return: A ContextManager for the message (type: Any) :rtype: Generator[Any] """ - + try: yield self.cache[msg_id] finally: @@ -313,9 +319,9 @@ def _acknowledge(self, msg_id: int) -> None: logger.info(f"ack fail: channel:{self.id} -> pub:{self.pub_id}") def register_client( - self, - client_id: UUID, - queue: NotificationQueue | None = None, + self, + client_id: UUID, + queue: NotificationQueue | None = None, local_backpressure: Backpressure | None = None, ) -> None: """ @@ -326,11 +332,11 @@ def register_client( :param queue: The notification queue for the subscribing client :type queue: asyncio.Queue[tuple[UUID, int]] | None :param local_backpressure: The backpressure object for the Publisher if it is in the same process - :type local_backpressure: Backpressure + :type local_backpressure: Backpressure """ self.clients[client_id] = queue if client_id == self.pub_id: - self._local_backpressure = local_backpressure + self._local_backpressure = local_backpressure def unregister_client(self, client_id: UUID) -> None: """ diff --git a/src/ezmsg/core/messagemarshal.py b/src/ezmsg/core/messagemarshal.py index e9f571a0..779fcc81 100644 --- a/src/ezmsg/core/messagemarshal.py +++ b/src/ezmsg/core/messagemarshal.py @@ -8,21 +8,22 @@ _PREAMBLE = b"EZ" _PREAMBLE_LEN = len(_PREAMBLE) -NO_MESSAGE = _PREAMBLE + (b'\xFF' * 8) + (b'\x00' * 8) +NO_MESSAGE = _PREAMBLE + (b"\xff" * 8) + (b"\x00" * 8) class UndersizedMemory(Exception): """ Exception raised when target memory buffer is too small for serialization. - + Contains the required size needed to successfully serialize the object. """ + req_size: int def __init__(self, *args: object, req_size: int = 0) -> None: """ Initialize UndersizedMemory exception. - + :param args: Exception arguments. :param req_size: Required memory size in bytes. :type req_size: int @@ -31,13 +32,14 @@ def __init__(self, *args: object, req_size: int = 0) -> None: self.req_size = req_size -class UninitializedMemory(Exception): +class UninitializedMemory(Exception): """ Exception raised when attempting to read from uninitialized memory. - + This occurs when trying to deserialize from a memory buffer that doesn't contain valid serialized data with the expected preamble. """ + ... @@ -57,7 +59,7 @@ class Marshal: def to_mem(cls, msg_id: int, obj: Any, mem: memoryview) -> None: """ Serialize an object with message ID into a memory buffer. - + :param msg_id: Unique message identifier. :type msg_id: int :param obj: Object to serialize. @@ -71,7 +73,7 @@ def to_mem(cls, msg_id: int, obj: Any, mem: memoryview) -> None: if total_size >= len(mem): raise UndersizedMemory(req_size=total_size) - + cls._write(mem, header, buffers) @classmethod @@ -92,7 +94,7 @@ def _assert_initialized(cls, raw: memoryview | bytes) -> None: def msg_id(cls, raw: memoryview | bytes) -> int: """ Get msg_id from a buffer; if uninitialized, return None. - + :param mem: buffer to read from. :type mem: memoryview | bytes :return: Message ID of encoded message @@ -102,16 +104,15 @@ def msg_id(cls, raw: memoryview | bytes) -> int: cls._assert_initialized(raw) return bytes_to_uint(raw[_PREAMBLE_LEN : _PREAMBLE_LEN + UINT64_SIZE]) - @classmethod @contextmanager def obj_from_mem(cls, mem: memoryview) -> Generator[Any, None, None]: """ Deserialize an object from a memory buffer. - + Provides a context manager for safe access to deserialized objects with automatic cleanup of memory views. - + :param mem: Memory buffer containing serialized object. :type mem: memoryview :return: Context manager yielding the deserialized object. @@ -124,7 +125,7 @@ def obj_from_mem(cls, mem: memoryview) -> Generator[Any, None, None]: num_buffers = bytes_to_uint(mem[sidx : sidx + UINT64_SIZE]) if num_buffers == 0: raise ValueError("invalid message in memory") - + sidx += UINT64_SIZE buf_sizes = [0] * num_buffers for i in range(num_buffers): @@ -152,10 +153,10 @@ def serialize( ) -> Generator[tuple[int, bytes, list[memoryview]], None, None]: """ Serialize an object for network transmission. - + Creates a complete serialization package with header and buffers suitable for network transmission. - + :param msg_id: Unique message identifier. :type msg_id: int :param obj: Object to serialize. @@ -181,7 +182,7 @@ def serialize( def dump(obj: Any) -> list[memoryview]: """ Serialize an object to a list of memory buffers using pickle. - + :param obj: Object to serialize. :type obj: Any :return: List of memory views containing serialized data. @@ -196,7 +197,7 @@ def dump(obj: Any) -> list[memoryview]: def load(buffers: list[memoryview]) -> Any: """ Deserialize an object from a list of memory buffers using pickle. - + :param buffers: List of memory views containing serialized data. :type buffers: list[memoryview] :return: Deserialized object. @@ -208,7 +209,7 @@ def load(buffers: list[memoryview]) -> Any: def copy_obj(cls, from_mem: memoryview, to_mem: memoryview) -> None: """ Copy obj in from_mem (if initialized) to to_mem. - + :param from_mem: Source memory buffer containing serialized object. :type from_mem: memoryview :param to_mem: Target memory buffer for copying. @@ -218,10 +219,9 @@ def copy_obj(cls, from_mem: memoryview, to_mem: memoryview) -> None: msg_id = cls.msg_id(from_mem) with MessageMarshal.obj_from_mem(from_mem) as obj: MessageMarshal.to_mem(msg_id, obj, to_mem) - -# If some other byte-level representation is desired, you can just -# monkeypatch the module at runtime with a different Marhsal subclass +# If some other byte-level representation is desired, you can just +# monkeypatch the module at runtime with a different Marhsal subclass # TODO: This could also be done with environment variables MessageMarshal = Marshal diff --git a/src/ezmsg/core/netprotocol.py b/src/ezmsg/core/netprotocol.py index 9fbde198..ee1d903c 100644 --- a/src/ezmsg/core/netprotocol.py +++ b/src/ezmsg/core/netprotocol.py @@ -29,18 +29,18 @@ PUBLISHER_START_PORT_DEFAULT = 25980 GRAPHSERVER_ADDR = os.environ.get( - GRAPHSERVER_ADDR_ENV, - f"{DEFAULT_HOST}:{GRAPHSERVER_PORT_DEFAULT}" + GRAPHSERVER_ADDR_ENV, f"{DEFAULT_HOST}:{GRAPHSERVER_PORT_DEFAULT}" ) class Address(typing.NamedTuple): """ Network address representation with host and port. - + Provides utility methods for address parsing, serialization, and socket binding operations. """ + host: str port: int @@ -48,7 +48,7 @@ class Address(typing.NamedTuple): async def from_stream(cls, reader: asyncio.StreamReader) -> "Address": """ Read an Address from an async stream. - + :param reader: Stream reader to read address string from. :type reader: asyncio.StreamReader :return: Parsed Address instance. @@ -61,7 +61,7 @@ async def from_stream(cls, reader: asyncio.StreamReader) -> "Address": def from_string(cls, address: str) -> "Address": """ Parse an Address from a string representation. - + :param address: Address string in "host:port" format. :type address: str :return: Parsed Address instance. @@ -73,7 +73,7 @@ def from_string(cls, address: str) -> "Address": def to_stream(self, writer: asyncio.StreamWriter) -> None: """ Write this address to an async stream. - + :param writer: Stream writer to send address string to. :type writer: asyncio.StreamWriter """ @@ -82,7 +82,7 @@ def to_stream(self, writer: asyncio.StreamWriter) -> None: def bind_socket(self) -> socket.socket: """ Create and bind a socket to this address. - + :return: Socket bound to this address. :rtype: socket.socket :raises IOError: If no free ports are available. @@ -100,10 +100,11 @@ def __str__(self): class ClientInfo: """ Base information for client connections. - + Tracks client identification, communication writer, and provides synchronized access to the writer for thread-safe operations. """ + id: UUID writer: asyncio.StreamWriter @@ -119,10 +120,10 @@ def set_sync(self) -> None: async def sync_writer(self) -> AsyncGenerator[asyncio.StreamWriter, None]: """ Get synchronized access to the writer. - + Ensures thread-safe access to the stream writer by coordinating access through an asyncio Event mechanism. - + :return: Context manager yielding the synchronized writer. :rtype: collections.abc.AsyncGenerator[asyncio.StreamWriter, None] """ @@ -141,6 +142,7 @@ class PublisherInfo(ClientInfo): Publisher-specific client information. Extends ClientInfo with the publisher's network address. """ + topic: str address: Address @@ -148,24 +150,25 @@ class PublisherInfo(ClientInfo): @dataclass class SubscriberInfo(ClientInfo): """ - Subscriber-specific client information. + Subscriber-specific client information. """ + topic: str @dataclass class ChannelInfo(ClientInfo): """ - Channel-specific client information. + Channel-specific client information. """ - pub_id: UUID + pub_id: UUID def uint64_to_bytes(i: int) -> bytes: """ Convert a 64-bit unsigned integer to bytes. - + :param i: Integer value to convert. :type i: int :return: Byte representation in little-endian format. @@ -177,7 +180,7 @@ def uint64_to_bytes(i: int) -> bytes: def bytes_to_uint(b: bytes) -> int: """ Convert bytes to a 64-bit unsigned integer. - + :param b: Byte data to convert. :type b: bytes :return: Integer value decoded from little-endian bytes. @@ -189,7 +192,7 @@ def bytes_to_uint(b: bytes) -> int: def encode_str(string: str) -> bytes: """ Encode a string with length prefix for network transmission. - + :param string: String to encode. :type string: str :return: Length-prefixed UTF-8 encoded bytes. @@ -203,7 +206,7 @@ def encode_str(string: str) -> bytes: async def read_int(reader: asyncio.StreamReader) -> int: """ Read a 64-bit unsigned integer from an async stream. - + :param reader: Stream reader to read from. :type reader: asyncio.StreamReader :return: Integer value read from stream. @@ -243,16 +246,16 @@ async def close_server(server: Server): class Command(enum.Enum): """ Enumeration of protocol commands for ezmsg network communication. - + Defines all command types used in the ezmsg protocol for graph management, publisher-subscriber communication, and shared memory operations. """ - + @staticmethod def _generate_next_value_(name, start, count, last_values) -> bytes: """ Generate byte values for enum members. - + :param name: Name of the enum member. :type name: str :param start: Starting value (unused). @@ -289,7 +292,7 @@ def _generate_next_value_(name, start, count, last_values) -> bytes: SHM_ATTACH = enum.auto() SHUTDOWN = enum.auto() - + CHANNEL = enum.auto() SHM_OK = enum.auto() SHM_ATTACH_FAILED = enum.auto() @@ -304,10 +307,10 @@ def create_socket( ) -> socket.socket: """ Create a socket bound to an available port. - + Attempts to bind to the specified port, or searches for an available port within the given range if no specific port is provided. - + :param host: Host address to bind to (defaults to DEFAULT_HOST). :type host: str | None :param port: Specific port to bind to (if None, searches for available port). @@ -333,7 +336,7 @@ def create_socket( sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) sock.bind((host, port)) return sock - + port = start_port while port <= max_port: if port not in ignore_ports: @@ -350,4 +353,3 @@ def create_socket( port += 1 raise IOError("Failed to bind socket; no free ports") - diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index 45e8a721..0094f7d5 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -49,12 +49,13 @@ class PubChannelInfo(ChannelInfo): class Publisher: """ A publisher client for broadcasting messages to subscribers. - + Publisher manages shared memory allocation, connection handling with subscribers, backpressure control, and supports both shared memory and TCP transport methods. Messages are broadcast to all connected subscribers with automatic cleanup and resource management. """ + id: UUID pid: int topic: str @@ -80,7 +81,7 @@ class Publisher: def client_type() -> bytes: """ Get the client type identifier for publishers. - + :return: Command byte identifying this as a publisher client. :rtype: bytes """ @@ -93,7 +94,6 @@ async def create( graph_address: AddressType | None = None, host: str | None = None, port: int | None = None, - buf_size: int = DEFAULT_SHM_SIZE, num_buffers: int = 32, start_paused: bool = False, @@ -101,7 +101,7 @@ async def create( ) -> "Publisher": """ Create a new Publisher instance and register it with the graph server. - + :param topic: The topic this publisher will broadcast to. :type topic: str :param graph_service: Service for graph server communication. @@ -127,13 +127,13 @@ async def create( pub_id = UUID(await read_str(reader)) pub = cls( - id = pub_id, - topic = topic, - shm = shm, - graph_address = graph_address, - num_buffers = num_buffers, - start_paused = start_paused, - force_tcp = force_tcp, + id=pub_id, + topic=topic, + shm=shm, + graph_address=graph_address, + num_buffers=num_buffers, + start_paused=start_paused, + force_tcp=force_tcp, ) start_port = int( @@ -142,33 +142,32 @@ async def create( sock = create_socket(host, port, start_port=start_port) server = await asyncio.start_server(pub._channel_connect, sock=sock) pub._connection_task = asyncio.create_task( - pub._serve_channels(server), - name = f'pub-{pub.id}: {pub.topic}' + pub._serve_channels(server), name=f"pub-{pub.id}: {pub.topic}" ) # Notify GraphServer that our server is up channel_server_address = Address(*sock.getsockname()) channel_server_address.to_stream(writer) - result = await reader.read(1) # channels connect + result = await reader.read(1) # channels connect if result != Command.COMPLETE.value: - logger.warning(f'Could not create publisher {topic=}') + logger.warning(f"Could not create publisher {topic=}") - # Pass off graph connection keep-alive to publisher task + # Pass off graph connection keep-alive to publisher task pub._graph_task = asyncio.create_task( pub._graph_connection(reader, writer), - name = f'pub-{pub.id}: _graph_connection' + name=f"pub-{pub.id}: _graph_connection", ) pub._local_channel = await CHANNELS.register_local_pub( - pub_id = pub.id, - local_backpressure = pub._backpressure, - graph_address = pub._graph_address, + pub_id=pub.id, + local_backpressure=pub._backpressure, + graph_address=pub._graph_address, ) - logger.debug(f'created pub {pub.id=} {topic=} {channel_server_address=}') + logger.debug(f"created pub {pub.id=} {topic=} {channel_server_address=}") return pub - + async def _serve_channels(self, server: asyncio.Server) -> None: try: await server.serve_forever() @@ -192,7 +191,7 @@ def __init__( """ Initialize a Publisher instance. DO NOT USE this constructor to make a Publisher; use `create` instead - + :param id: Unique identifier for this publisher. :type id: UUID :param topic: The topic this publisher broadcasts to. @@ -229,7 +228,7 @@ def log_name(self) -> str: def close(self) -> None: """ Close the publisher and cancel all associated tasks. - + Cancels graph connection, shared memory, connection server, and all subscriber handling tasks. """ @@ -242,7 +241,7 @@ def close(self) -> None: async def wait_closed(self) -> None: """ Wait for all publisher resources to be fully closed. - + Waits for shared memory cleanup, graph connection termination, connection server shutdown, and all subscriber tasks to complete. """ @@ -260,10 +259,10 @@ async def _graph_connection( ) -> None: """ Handle communication with the graph server. - + Processes commands from the graph server including COMPLETE, PAUSE, RESUME, and SYNC operations. - + :param reader: Stream reader for receiving commands from graph server. :type reader: asyncio.StreamReader :param writer: Stream writer for responding to graph server. @@ -303,10 +302,10 @@ async def _channel_connect( ) -> None: """ Handle new subscriber connections. - + Exchanges identification information with connecting subscribers and sets up subscriber handling tasks. - + :param reader: Stream reader for receiving subscriber info. :type reader: asyncio.StreamReader :param writer: Stream writer for sending publisher info. @@ -316,7 +315,7 @@ async def _channel_connect( if len(cmd) == 0: return - + if cmd == Command.CHANNEL.value: channel_id_str = await read_str(reader) channel_id = UUID(channel_id_str) @@ -335,10 +334,10 @@ async def _handle_channel( ) -> None: """ Handle communication with a specific channel. - + Processes acknowledgments from channels and manages backpressure control based on channel feedback. - + :param info: Information about the channel connection. :type info: PubChannelInfo :param reader: Stream reader for receiving channel messages. @@ -368,7 +367,7 @@ async def _handle_channel( async def sync(self) -> None: """ Pause and drain backpressure. - + Temporarily pauses the publisher and waits for all pending messages to be acknowledged by subscribers. """ @@ -379,7 +378,7 @@ async def sync(self) -> None: def running(self) -> bool: """ Check if the publisher is currently running. - + :return: True if publisher is running and accepting broadcasts. :rtype: bool """ @@ -388,7 +387,7 @@ def running(self) -> bool: def pause(self) -> None: """ Pause the publisher to stop broadcasting messages. - + Messages sent to broadcast() will block until resumed. """ self._running.clear() @@ -396,7 +395,7 @@ def pause(self) -> None: def resume(self) -> None: """ Resume the publisher to allow broadcasting messages. - + Unblocks any pending broadcast() calls. """ self._running.set() @@ -404,10 +403,10 @@ def resume(self) -> None: async def broadcast(self, obj: Any) -> None: """ Broadcast a message to all connected subscribers. - + Handles message serialization, shared memory management, transport selection (local/SHM/TCP), and backpressure control automatically. - + :param obj: The object/message to broadcast to subscribers. :type obj: Any """ @@ -426,14 +425,24 @@ async def broadcast(self, obj: Any) -> None: # Get local channel and put variable there for local tx self._local_channel.put_local(self._msg_id, obj) - if self._force_tcp or any(ch.pid != self.pid or not ch.shm_ok for ch in self._channels.values()): - with MessageMarshal.serialize(self._msg_id, obj) as (total_size, header, buffers): + if self._force_tcp or any( + ch.pid != self.pid or not ch.shm_ok for ch in self._channels.values() + ): + with MessageMarshal.serialize(self._msg_id, obj) as ( + total_size, + header, + buffers, + ): total_size_bytes = uint64_to_bytes(total_size) - if not self._force_tcp and any(ch.pid != self.pid and ch.shm_ok for ch in self._channels.values()): + if not self._force_tcp and any( + ch.pid != self.pid and ch.shm_ok for ch in self._channels.values() + ): if self._shm.buf_size < total_size: - new_shm = await GraphService(self._graph_address).create_shm(self._num_buffers, total_size * 2) - + new_shm = await GraphService(self._graph_address).create_shm( + self._num_buffers, total_size * 2 + ) + for i in range(self._num_buffers): try: with self._shm.buffer(i, readonly=True) as from_buf: @@ -441,7 +450,7 @@ async def broadcast(self, obj: Any) -> None: MessageMarshal.copy_obj(from_buf, to_buf) except UninitializedMemory: pass - + self._shm.close() await self._shm.wait_closed() self._shm = new_shm @@ -450,26 +459,29 @@ async def broadcast(self, obj: Any) -> None: MessageMarshal._write(mem, header, buffers) for channel in self._channels.values(): - - msg: bytes = b'' + msg: bytes = b"" if self.pid == channel.pid and channel.shm_ok: - continue # Local transmission handled by channel.put + continue # Local transmission handled by channel.put - elif (not self._force_tcp) and self.pid != channel.pid and channel.shm_ok: + elif ( + (not self._force_tcp) + and self.pid != channel.pid + and channel.shm_ok + ): msg = ( - Command.TX_SHM.value + - msg_id_bytes + - encode_str(self._shm.name) + Command.TX_SHM.value + + msg_id_bytes + + encode_str(self._shm.name) ) else: msg = ( - Command.TX_TCP.value + - msg_id_bytes + - total_size_bytes + - header + - b''.join([buffer for buffer in buffers]) + Command.TX_TCP.value + + msg_id_bytes + + total_size_bytes + + header + + b"".join([buffer for buffer in buffers]) ) try: diff --git a/src/ezmsg/core/settings.py b/src/ezmsg/core/settings.py index 1ffad2a0..d8766387 100644 --- a/src/ezmsg/core/settings.py +++ b/src/ezmsg/core/settings.py @@ -18,10 +18,11 @@ class SettingsMeta(ABCMeta): """ Metaclass that automatically applies dataclass decorator to Settings classes. - + This metaclass ensures all Settings subclasses are automatically converted to frozen dataclasses, providing immutability and proper initialization. """ + def __new__( cls, name: str, @@ -31,7 +32,7 @@ def __new__( ) -> type["Settings"]: """ Create a new Settings class with dataclass transformation. - + :param name: Name of the class being created. :type name: str :param bases: Base classes for the new class. diff --git a/src/ezmsg/core/shm.py b/src/ezmsg/core/shm.py index 4c5ad140..abfdc2d2 100644 --- a/src/ezmsg/core/shm.py +++ b/src/ezmsg/core/shm.py @@ -1,9 +1,7 @@ import asyncio from collections.abc import Generator import logging -import typing -from uuid import UUID from dataclasses import dataclass, field from contextlib import contextmanager, suppress from multiprocessing import resource_tracker @@ -19,6 +17,7 @@ _std_register = resource_tracker.register + def _ignore_shm(name, rtype): if rtype == "shared_memory": return @@ -29,10 +28,10 @@ def _ignore_shm(name, rtype): def _untracked_shm() -> Generator[None, None, None]: """ Disable SHM tracking within context - https://bugs.python.org/issue38119. - + This context manager temporarily disables shared memory tracking to work around a Python bug where shared memory segments are not properly cleaned up. - + :return: Context manager generator. :rtype: collections.abc.Generator[None, None, None] """ @@ -61,6 +60,7 @@ class SHMContext: This format repeats itself for every buffer in the SharedMemory block. """ + num_buffers: int buf_size: int @@ -71,7 +71,7 @@ class SHMContext: def __init__(self, name: str) -> None: """ Initialize SHMContext by connecting to an existing shared memory segment. - + :param name: The name of the shared memory segment to connect to. :type name: str :raises BufferError: If shared memory segment cannot be accessed. @@ -99,7 +99,7 @@ def attach( ) -> "SHMContext": """ Create a new SHMContext with connection monitoring. - + :param shm_name: Name of the shared memory segment. :type shm_name: str :param reader: Stream reader for connection monitoring. @@ -111,11 +111,10 @@ def attach( """ context = cls(shm_name) context._graph_task = asyncio.create_task( - context._graph_connection(reader, writer), - name=f"{context.name}_monitor" + context._graph_connection(reader, writer), name=f"{context.name}_monitor" ) return context - + async def _graph_connection( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter ) -> None: @@ -131,7 +130,7 @@ async def _graph_connection( def __getitem__(self, idx: int) -> memoryview: """ Get a memory view of a specific buffer in the shared memory segment. - + :param idx: Index of the buffer to access. :type idx: int :return: A memoryview of the buffer. @@ -148,7 +147,7 @@ def buffer( ) -> Generator[memoryview, None, None]: """ Get a memory view of a specific buffer in the shared memory segment. - + :param idx: Index of the buffer to access. :type idx: int :param readonly: Whether to provide read-only access to the buffer. @@ -166,7 +165,7 @@ def buffer( def close(self) -> None: """ Close the shared memory context and cancel monitoring. - + This initiates an asynchronous close operation and cancels the connection monitor task. """ @@ -175,18 +174,18 @@ def close(self) -> None: async def wait_closed(self) -> None: """ Wait for the shared memory context to be fully closed. - + This method waits for the monitoring task to complete, indicating that the connection has been properly terminated. """ with suppress(asyncio.CancelledError): await self._graph_task - + @property def name(self) -> str: """ Get the name of the shared memory segment. - + :return: The shared memory segment name. :rtype: str """ @@ -196,7 +195,7 @@ def name(self) -> str: def size(self) -> int: """ Get the usable size of each buffer (excluding header). - + :return: Buffer size minus 16-byte header. :rtype: int """ @@ -207,33 +206,32 @@ def size(self) -> int: class SHMInfo: """ Information about a shared memory segment and its active leases. - + Tracks the SharedMemory object and manages client connection leases. When all leases are released, the shared memory is automatically cleaned up. """ + shm: SharedMemory leases: set["asyncio.Task[None]"] = field(default_factory=set) @classmethod - def create( - cls, num_buffers: int, buf_size: int - ) -> "SHMInfo": - buf_size += 16 * num_buffers # Repeated header info makes this a bit bigger + def create(cls, num_buffers: int, buf_size: int) -> "SHMInfo": + buf_size += 16 * num_buffers # Repeated header info makes this a bit bigger shm = SharedMemory(size=num_buffers * buf_size, create=True) shm.buf[:] = b"0" * len(shm.buf) # Guarantee zeros shm.buf[0:8] = uint64_to_bytes(num_buffers) shm.buf[8:16] = uint64_to_bytes(buf_size) return cls(shm) - + def lease( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter ) -> "asyncio.Task[None]": """ Create a lease for this shared memory segment. - + The lease monitors the client connection and automatically releases the shared memory when the client disconnects. - + :param reader: Stream reader to monitor for client disconnection. :type reader: asyncio.StreamReader :param writer: Stream writer for connection cleanup. diff --git a/src/ezmsg/core/state.py b/src/ezmsg/core/state.py index 385f3520..c298fd7d 100644 --- a/src/ezmsg/core/state.py +++ b/src/ezmsg/core/state.py @@ -7,10 +7,11 @@ class StateMeta(ABCMeta): """ Metaclass that automatically applies dataclass decorator to State classes. - + This metaclass ensures all State subclasses are automatically converted to mutable dataclasses with hash support but no automatic initialization. """ + def __new__( cls, name: str, @@ -20,7 +21,7 @@ def __new__( ) -> type["State"]: """ Create a new State class with dataclass transformation. - + :param name: Name of the class being created. :type name: str :param bases: Base classes for the new class. diff --git a/src/ezmsg/core/stream.py b/src/ezmsg/core/stream.py index 178e746c..2af20f3c 100644 --- a/src/ezmsg/core/stream.py +++ b/src/ezmsg/core/stream.py @@ -7,7 +7,7 @@ class Stream(Addressable): """ Base class for all streams in the ezmsg framework. - + Streams define the communication channels between components, carrying messages of a specific type through the system. @@ -29,7 +29,7 @@ def __repr__(self) -> str: class InputStream(Stream): """ Can be added to any Component as a member variable. Methods may subscribe to it. - + InputStream represents a channel that receives messages from other components. Units can subscribe to InputStreams to process incoming messages. @@ -44,7 +44,7 @@ def __repr__(self) -> str: class OutputStream(Stream): """ Can be added to any Component as a member variable. Methods may publish to it. - + OutputStream represents a channel that sends messages to other components. Units can publish to OutputStreams to send messages through the system. diff --git a/src/ezmsg/core/subclient.py b/src/ezmsg/core/subclient.py index ad645d3d..981233f4 100644 --- a/src/ezmsg/core/subclient.py +++ b/src/ezmsg/core/subclient.py @@ -25,11 +25,12 @@ class Subscriber: """ A subscriber client for receiving messages from publishers. - + Subscriber manages connections to multiple publishers, handles different transport methods (local, shared memory, TCP), and provides both copying and zero-copy message access patterns with automatic acknowledgment. """ + id: UUID topic: str @@ -50,14 +51,11 @@ class Subscriber: @classmethod async def create( - cls, - topic: str, - graph_address: AddressType | None, - **kwargs + cls, topic: str, graph_address: AddressType | None, **kwargs ) -> "Subscriber": """ Create a new Subscriber instance and register it with the graph server. - + :param topic: The topic this subscriber will listen to. :type topic: str :param graph_service: Service for graph server communication. @@ -73,12 +71,12 @@ async def create( writer.write(encode_str(topic)) sub_id_str = await read_str(reader) sub_id = UUID(sub_id_str) - + sub = cls(sub_id, topic, graph_address, **kwargs) sub._graph_task = asyncio.create_task( sub._graph_connection(reader, writer), - name = f'sub-{sub.id}: _graph_connection' + name=f"sub-{sub.id}: _graph_connection", ) # FIXME: We need to wait for _graph_task to service an UPDATE @@ -86,16 +84,12 @@ async def create( # subscriber ready for recv. await sub._initialized.wait() - logger.debug(f'created sub {sub.id=} {topic=}') + logger.debug(f"created sub {sub.id=} {topic=}") return sub def __init__( - self, - id: UUID, - topic: str, - graph_address: AddressType | None, - **kwargs + self, id: UUID, topic: str, graph_address: AddressType | None, **kwargs ) -> None: """ Initialize a Subscriber instance. @@ -122,7 +116,7 @@ def __init__( def close(self) -> None: """ Close the subscriber and cancel all associated tasks. - + Cancels graph connection, all publisher connection tasks, and closes all shared memory contexts. """ @@ -131,7 +125,7 @@ def close(self) -> None: async def wait_closed(self) -> None: """ Wait for all subscriber resources to be fully closed. - + Waits for graph connection termination, all publisher connection tasks to complete, and all shared memory contexts to close. """ @@ -143,10 +137,10 @@ async def _graph_connection( ) -> None: """ Handle communication with the graph server. - + Processes commands from the graph server including COMPLETE and UPDATE operations for managing publisher connections. - + :param reader: Stream reader for receiving commands from graph server. :type reader: asyncio.StreamReader :param writer: Stream writer for responding to graph server. @@ -172,23 +166,27 @@ async def _graph_connection( # event, which we wait on in Subscriber.create, releasing the # block and returning a fully connected/ready subscriber. # While this currently works, its non-obvious what this - # accomplishes and why its implemented this way. I just + # accomplishes and why its implemented this way. I just # wasted 2 hours removing this seemingly un-necessary event - # to introduce a bug that is only resolved by reintroducing + # to introduce a bug that is only resolved by reintroducing # the event. Some thought should be put into replacing the # bespoke communication protocol with something a bit more # standard (JSON RPC?) with a more common/logical pattern - # for creating/handling comms. We will probably want to + # for creating/handling comms. We will probably want to # keep the bespoke comms for the publisher->channel link # as that is in the hot path. self._initialized.set() elif cmd == Command.UPDATE.value: update = await read_str(reader) - pub_ids = set([UUID(id) for id in update.split(',')]) if update else set() + pub_ids = ( + set([UUID(id) for id in update.split(",")]) if update else set() + ) for pub_id in set(pub_ids - self._cur_pubs): - channel = await CHANNELS.register(pub_id, self.id, self._incoming, self._graph_address) + channel = await CHANNELS.register( + pub_id, self.id, self._incoming, self._graph_address + ) self._channels[pub_id] = channel for pub_id in set(self._cur_pubs - pub_ids): @@ -204,7 +202,9 @@ async def _graph_connection( ) except (ConnectionResetError, BrokenPipeError): - logger.debug(f"Subscriber {self.topic}({self.id}) lost connection to graph server") + logger.debug( + f"Subscriber {self.topic}({self.id}) lost connection to graph server" + ) finally: for pub_id in self._channels: @@ -214,10 +214,10 @@ async def _graph_connection( async def recv(self) -> typing.Any: """ Receive the next message with a deep copy. - + This method creates a deep copy of the received message, allowing safe modification without affecting the original cached message. - + :return: Deep copy of the received message. :rtype: typing.Any """ @@ -230,11 +230,11 @@ async def recv(self) -> typing.Any: async def recv_zero_copy(self) -> typing.AsyncGenerator[typing.Any, None]: """ Receive the next message with zero-copy access. - + This context manager provides direct access to the cached message without copying. The message should not be modified or stored beyond the context manager's scope. - + :return: Context manager yielding the received message. :rtype: collections.abc.AsyncGenerator[typing.Any, None] """ @@ -242,5 +242,3 @@ async def recv_zero_copy(self) -> typing.AsyncGenerator[typing.Any, None]: with self._channels[pub_id].get(msg_id, self.id) as msg: yield msg - - diff --git a/src/ezmsg/core/test.py b/src/ezmsg/core/test.py index ce1e705b..48ec4d95 100644 --- a/src/ezmsg/core/test.py +++ b/src/ezmsg/core/test.py @@ -1,15 +1,18 @@ - # tutorial_pipeline.py +# tutorial_pipeline.py import ezmsg.core as ez from dataclasses import dataclass from collections.abc import AsyncGenerator + class CountSettings(ez.Settings): iterations: int + @dataclass class CountMessage: value: int + class Count(ez.Unit): SETTINGS = CountSettings @@ -24,6 +27,7 @@ async def count(self) -> AsyncGenerator: raise ez.NormalTermination + class AddOne(ez.Unit): INPUT_COUNT = ez.InputStream(CountMessage) OUTPUT_PLUS_ONE = ez.OutputStream(CountMessage) @@ -33,6 +37,7 @@ class AddOne(ez.Unit): async def on_message(self, message) -> AsyncGenerator: yield self.OUTPUT_PLUS_ONE, CountMessage(value=message.value + 1) + class PrintValue(ez.Unit): INPUT = ez.InputStream(CountMessage) @@ -40,13 +45,14 @@ class PrintValue(ez.Unit): async def on_message(self, message) -> None: print(message.value) + components = { "COUNT": Count(settings=CountSettings(iterations=10)), "ADD_ONE": AddOne(), - "PRINT": PrintValue() + "PRINT": PrintValue(), } connections = ( (components["COUNT"].OUTPUT_COUNT, components["ADD_ONE"].INPUT_COUNT), - (components["ADD_ONE"].OUTPUT_PLUS_ONE, components["PRINT"].INPUT) + (components["ADD_ONE"].OUTPUT_PLUS_ONE, components["PRINT"].INPUT), ) -ez.run(components=components, connections=connections) \ No newline at end of file +ez.run(components=components, connections=connections) diff --git a/src/ezmsg/core/unit.py b/src/ezmsg/core/unit.py index 045019a1..d957d06c 100644 --- a/src/ezmsg/core/unit.py +++ b/src/ezmsg/core/unit.py @@ -62,11 +62,11 @@ def __init__( class Unit(Component, metaclass=UnitMeta): """ Represents a single step in the graph. - + Units can subscribe, publish, and have tasks. Units are the fundamental building blocks of ezmsg applications that perform actual computation and message processing. To create a Unit, inherit from the Unit class. - + :param settings: Optional settings object for unit configuration :type settings: Settings | None """ @@ -97,11 +97,11 @@ async def setup(self) -> None: async def initialize(self) -> None: """ Runs when the Unit is instantiated. - + This is called from within the same process this unit will live in. - This lifecycle hook can be overridden. It can be run as async functions + This lifecycle hook can be overridden. It can be run as async functions by simply adding the async keyword when overriding. - + This method is where you should initialize your unit's state and prepare for message processing. """ @@ -110,11 +110,11 @@ async def initialize(self) -> None: async def shutdown(self) -> None: """ Runs when the Unit terminates. - + This is called from within the same process this unit will live in. - This lifecycle hook can be overridden. It can be run as async functions + This lifecycle hook can be overridden. It can be run as async functions by simply adding the async keyword when overriding. - + This method is where you should clean up resources and perform any necessary shutdown procedures. """ @@ -124,7 +124,7 @@ async def shutdown(self) -> None: def publisher(stream: OutputStream): """ A decorator for a method that publishes to a stream in the task/messaging thread. - + An async function will yield messages on the designated :obj:`OutputStream`. A function can have both ``@subscriber`` and ``@publisher`` decorators. @@ -135,7 +135,7 @@ def publisher(stream: OutputStream): :raises ValueError: If stream is not an OutputStream Example usage: - + .. code-block:: python from collections.abc import AsyncGenerator @@ -163,8 +163,8 @@ def pub_factory(func): def subscriber(stream: InputStream, zero_copy: bool = False): """ A decorator for a method that subscribes to a stream in the task/messaging thread. - - An async function will run once per message received from the :obj:`InputStream` + + An async function will run once per message received from the :obj:`InputStream` it subscribes to. A function can have both ``@subscriber`` and ``@publisher`` decorators. :param stream: The input stream to subscribe to @@ -203,7 +203,7 @@ def sub_factory(func): def main(func: Callable): """ A decorator which designates this function to run as the main thread for this Unit. - + A Unit may only have one main function. The main function runs independently of the message processing and is typically used for initialization, background processing, or cleanup tasks. @@ -220,7 +220,7 @@ def main(func: Callable): def timeit(func: Callable): """ A decorator that logs the execution time of the decorated function. - + ezmsg will log the amount of time this function takes to execute to the ezmsg logger. This is useful for performance monitoring and optimization. The execution time is logged in milliseconds. @@ -252,7 +252,7 @@ def wrapper(self, *args, **kwargs): def thread(func: Callable): """ A decorator which designates this function to run as a background thread for this Unit. - + Thread functions run concurrently with the main message processing and can be used for background tasks, monitoring, or other concurrent operations. @@ -268,7 +268,7 @@ def thread(func: Callable): def task(func: Callable): """ A decorator which designates this function to run as a task in the task/messaging thread. - + Task functions are part of the main message processing pipeline and are executed within the unit's primary execution context. @@ -284,7 +284,7 @@ def task(func: Callable): def process(func: Callable): """ A decorator which designates this function to run in its own process. - + Process functions run in separate processes for isolation and can be used for CPU-intensive operations or when process isolation is required. diff --git a/src/ezmsg/core/util.py b/src/ezmsg/core/util.py index 54cae73a..3bb168f8 100644 --- a/src/ezmsg/core/util.py +++ b/src/ezmsg/core/util.py @@ -1,5 +1,4 @@ from collections.abc import Mapping -import sys import typing @@ -9,7 +8,7 @@ def is_dict_like(value: typing.Any) -> typing.TypeGuard[Mapping]: """ Check if a value behaves like a dictionary. - + This function checks if the value has the basic dictionary interface by verifying it has 'keys' and '__getitem__' attributes. @@ -29,7 +28,7 @@ def either_dict_or_kwargs( """ Handle flexible argument passing patterns for functions that accept either positional dict or keyword arguments. - + This utility function helps implement the common pattern where a function can accept either a dictionary as the first argument or keyword arguments, but not both. diff --git a/src/ezmsg/util/__init__.py b/src/ezmsg/util/__init__.py index ccc806ea..c722ac66 100644 --- a/src/ezmsg/util/__init__.py +++ b/src/ezmsg/util/__init__.py @@ -6,7 +6,7 @@ Key modules: - :mod:`ezmsg.util.debuglog`: Debug logging utilities -- :mod:`ezmsg.util.generator`: Generator-based message processing decorators +- :mod:`ezmsg.util.generator`: Generator-based message processing decorators - :mod:`ezmsg.util.messagecodec`: JSON encoding/decoding for message logging - :mod:`ezmsg.util.messagegate`: Message flow control and gating - :mod:`ezmsg.util.messagelogger`: File-based message logging diff --git a/src/ezmsg/util/messagecodec.py b/src/ezmsg/util/messagecodec.py index 779be16a..2e3db03f 100644 --- a/src/ezmsg/util/messagecodec.py +++ b/src/ezmsg/util/messagecodec.py @@ -10,6 +10,7 @@ The MessageEncoder and MessageDecoder classes handle automatic conversion between Python objects and JSON representations suitable for file logging. """ + from collections.abc import Iterable, Generator import json import pickle @@ -44,7 +45,7 @@ class LogStart: ... def type_str(obj: typing.Any) -> str: """ Get a string representation of an object's type for serialization. - + :param obj: Object to get type string for :type obj: typing.Any :return: String representation in format 'module:qualname' @@ -58,7 +59,7 @@ def type_str(obj: typing.Any) -> str: def import_type(typestr: str) -> type: """ Import a type from a string representation. - + :param typestr: String representation in format 'module:qualname' :type typestr: str :return: The imported type @@ -78,13 +79,14 @@ def import_type(typestr: str) -> type: class MessageEncoder(json.JSONEncoder): """ JSON encoder for ezmsg messages with support for dataclasses, numpy arrays, and arbitrary objects. - + This encoder extends the standard JSON encoder to handle: - + - Dataclass objects (serialized as dictionaries with type information) - NumPy arrays (serialized as base64-encoded data with metadata) - Other objects via pickle (as fallback) """ + def default(self, o: typing.Any): if is_dataclass(o): return { @@ -116,12 +118,13 @@ def default(self, o: typing.Any): class StampedMessage(typing.NamedTuple): """ A message with an associated timestamp. - + :param msg: The message object :type msg: typing.Any :param timestamp: Optional timestamp for the message :type timestamp: float | None """ + msg: typing.Any timestamp: float | None @@ -129,10 +132,10 @@ class StampedMessage(typing.NamedTuple): def _object_hook(obj: dict[str, typing.Any]) -> typing.Any: """ JSON object hook for decoding ezmsg messages. - + Handles reconstruction of dataclasses, numpy arrays, and pickled objects from their JSON representations. - + :param obj: Dictionary from JSON decoder :type obj: dict[str, typing.Any] :return: Reconstructed object @@ -145,9 +148,7 @@ def _object_hook(obj: dict[str, typing.Any]) -> typing.Any: if obj_type is not None: if np and obj_type == NDARRAY_TYPE: data_bytes: str | None = obj.get(NDARRAY_DATA) - data_shape: Iterable[int] | None = obj.get( - NDARRAY_SHAPE - ) + data_shape: Iterable[int] | None = obj.get(NDARRAY_SHAPE) data_dtype: npt.DTypeLike | None = obj.get(NDARRAY_DTYPE) if ( @@ -175,22 +176,21 @@ def _object_hook(obj: dict[str, typing.Any]) -> typing.Any: class MessageDecoder(json.JSONDecoder): """ JSON decoder for ezmsg messages. - + Automatically reconstructs dataclasses, numpy arrays, and pickled objects from their JSON representations using the _object_hook function. """ + def __init__(self, *args, **kwargs): json.JSONDecoder.__init__(self, object_hook=_object_hook, *args, **kwargs) - - def message_log( fname: Path, return_object: bool = True ) -> Generator[typing.Any, None, None]: """ Generator function to read messages from a log file created by MessageLogger. - + :param fname: Path to the log file :type fname: Path :param return_object: If True, yield only the message objects; if False, yield complete log entries diff --git a/src/ezmsg/util/messagegate.py b/src/ezmsg/util/messagegate.py index 5d9e35c4..379c95b9 100644 --- a/src/ezmsg/util/messagegate.py +++ b/src/ezmsg/util/messagegate.py @@ -9,7 +9,7 @@ class GateMessage: """ Send this message to ``INPUT_GATE`` to open or close the gate. - + :param open: True to open the gate (allow messages), False to close it (discard messages) :type open: bool """ @@ -66,7 +66,7 @@ async def initialize(self) -> None: def set_gate(self, set_open: bool) -> None: """ Set the gate open/closed state and reset message counter. - + :param set_open: True to open gate, False to close it :type set_open: bool """ diff --git a/src/ezmsg/util/messagelogger.py b/src/ezmsg/util/messagelogger.py index d91b9fb0..cd713f6c 100644 --- a/src/ezmsg/util/messagelogger.py +++ b/src/ezmsg/util/messagelogger.py @@ -16,7 +16,7 @@ def log_object(obj: typing.Any) -> str: """ Convert an object to a JSON string with timestamp for logging. - + :param obj: Object to convert to log string :type obj: typing.Any :return: JSON string containing timestamp and object @@ -86,7 +86,7 @@ class MessageLogger(ez.Unit): def open_file(self, filepath: Path) -> Path | None: """ Open a file for message logging. - + :param filepath: Path to the file to open :type filepath: Path :return: File path if file successfully opened, otherwise None @@ -109,7 +109,7 @@ def open_file(self, filepath: Path) -> Path | None: def close_file(self, filepath: Path) -> Path | None: """ Close a file that was being used for message logging. - + :param filepath: Path to the file to close :type filepath: Path :return: File path if file successfully closed, otherwise None diff --git a/src/ezmsg/util/messages/axisarray.py b/src/ezmsg/util/messages/axisarray.py index 6d627538..5324cc57 100644 --- a/src/ezmsg/util/messages/axisarray.py +++ b/src/ezmsg/util/messages/axisarray.py @@ -12,7 +12,6 @@ try: import numpy as np import numpy.typing as npt - import numpy.lib.stride_tricks as nps except ModuleNotFoundError: ez.logger.error( 'Install ezmsg with the AxisArray extra:pip install "ezmsg[AxisArray]"' @@ -36,6 +35,7 @@ @dataclass class AxisBase(ABC): """Abstract base class for axes types used by AxisArray.""" + unit: str = "" @typing.overload @@ -50,13 +50,13 @@ def value(self, x): @dataclass class LinearAxis(AxisBase): """ - An axis implementation for sparse axes with regular intervals between elements. - + An axis implementation for sparse axes with regular intervals between elements. + It is called "linear" because it provides a simple linear mapping between indices and element values: value = (index * gain) + offset. - A typical example is a time axis (TimeAxis), with regular sampling rate. - + A typical example is a time axis (TimeAxis), with regular sampling rate. + :param gain: Step size (scaling factor) for the linear axis :type gain: float :param offset: The offset (value) of the first sample @@ -64,6 +64,7 @@ class LinearAxis(AxisBase): :param unit: The unit of measurement for this axis (inherited from AxisBase) :type unit: str """ + gain: float = 1.0 offset: float = 0.0 @@ -74,7 +75,7 @@ def value(self, x: npt.NDArray[np.int_]) -> npt.NDArray[np.float64]: ... def value(self, x): """ Convert index(es) to axis value(s) using the linear transformation. - + :param x: Index or array of indices to convert :type x: int | npt.NDArray[:class:`numpy.int_`] :return: Corresponding axis value(s) @@ -105,7 +106,7 @@ def index(self, v, fn=np.rint): def create_time_axis(cls, fs: float, offset: float = 0.0) -> "LinearAxis": """ Convenience method to construct a LinearAxis for time. - + :param fs: Sampling frequency in Hz :type fs: float :param offset: Time offset in seconds (default: 0.0) @@ -120,17 +121,18 @@ def create_time_axis(cls, fs: float, offset: float = 0.0) -> "LinearAxis": class ArrayWithNamedDims: """ Base class for arrays with named dimensions. - + This class provides a foundation for arrays where each dimension has a name, enabling more intuitive array manipulation and access patterns. - + :param data: The underlying numpy array data :type data: npt.NDArray :param dims: List of dimension names, must match the array's number of dimensions :type dims: list[str] - + :raises ValueError: If dims length doesn't match data.ndim or contains duplicate names """ + data: npt.NDArray dims: list[str] @@ -154,10 +156,10 @@ def __eq__(self, other): class CoordinateAxis(AxisBase, ArrayWithNamedDims): """ An axis implementation that uses explicit coordinate values stored in an array. - + This class allows for non-linear or irregularly spaced coordinate systems by storing the actual coordinate values in a data array. - + Inherits from both AxisBase and ArrayWithNamedDims, combining axis functionality with named dimension support. @@ -168,6 +170,7 @@ class CoordinateAxis(AxisBase, ArrayWithNamedDims): :param dims: List of dimension names, must match the array's number of dimensions (inherited from ArrayWithNamedDims) :type dims: list[str] """ + @typing.overload def value(self, x: int) -> typing.Any: ... @typing.overload @@ -175,7 +178,7 @@ def value(self, x: npt.NDArray[np.int_]) -> npt.NDArray: ... def value(self, x): """ Get coordinate value(s) at the given index(es). - + :param x: Index or array of indices to lookup :type x: int | npt.NDArray[:class:`numpy.int_`] :return: Coordinate value(s) at the specified index(es) @@ -188,15 +191,15 @@ def value(self, x): class AxisArray(ArrayWithNamedDims): """ A lightweight message class comprising a numpy ndarray and its metadata. - + AxisArray extends ArrayWithNamedDims to provide a complete data structure for scientific computing with named dimensions, axis coordinate systems, and metadata. It's designed to be similar to xarray.DataArray but optimized for message passing in streaming applications. - + :param data: The underlying numpy array data (inherited from ArrayWithNamedDims) :type data: npt.NDArray - :param dims: List of dimension names (inherited from ArrayWithNamedDims) + :param dims: List of dimension names (inherited from ArrayWithNamedDims) :type dims: list[str] :param axes: Dictionary mapping dimension names to their axis coordinate systems :type axes: dict[str, AxisBase] @@ -205,6 +208,7 @@ class AxisArray(ArrayWithNamedDims): :param key: Optional key identifier for this array, typically used to specify source device (default is empty string) :type key: str """ + axes: dict[str, AxisBase] = field(default_factory=dict) attrs: dict[str, typing.Any] = field(default_factory=dict) key: str = "" @@ -216,7 +220,7 @@ def __eq__(self, other): # returns NotImplemented if classes aren't equal. Unintuitively, # NotImplemented seems to evaluate as 'True' in an if statement. equal = super().__eq__(other) - if equal != True: + if equal is not True: return equal # checks for AxisArray fields @@ -233,10 +237,11 @@ def __eq__(self, other): class Axis(LinearAxis): """ Deprecated alias for LinearAxis. - + .. deprecated:: 3.6.0 Use :class:`LinearAxis` instead. """ + def __post_init__(self) -> None: warnings.warn( "AxisArray.Axis is a deprecated alias for LinearAxis", @@ -263,19 +268,20 @@ def index(self, val): class AxisInfo: """ Container for axis information including the axis object, index, and size. - + This class provides a convenient way to access both the axis coordinate system and metadata about where that axis appears in the array structure. - + :param axis: The axis coordinate system :type axis: AxisBase :param idx: The index of this axis in the AxisArray object's dimension list - :type idx: int + :type idx: int :param size: The size of this dimension (stored as None for CoordinateAxis which determines size from data) :type size: int | None - + :raises ValueError: If size rules are violated for the axis type """ + axis: AxisBase idx: int # TODO (kpilch): rename this to _size as preferred usage is len(obj), not obj.size @@ -300,7 +306,7 @@ def __len__(self) -> int: def indices(self) -> npt.NDArray[np.int_]: """ Get array of all valid indices for this axis. - + :return: Array of indices from 0 to len(self)-1 :rtype: npt.NDArray[:class:`numpy.int_`] """ @@ -310,7 +316,7 @@ def indices(self) -> npt.NDArray[np.int_]: def values(self) -> npt.NDArray: """ Get array of coordinate values for all indices of this axis. - + :return: Array of coordinate values computed from axis.value(indices) :rtype: npt.NDArray """ @@ -323,13 +329,13 @@ def isel( ) -> T: """ Select data using integer-based indexing along specified dimensions. - + This method allows for flexible indexing using integers, slices, or arrays of integers to select subsets of the data along named dimensions. - + :param indexers: Dictionary of {dimension_name: indexer} pairs :type indexers: typing.Any | None - :param indexers_kwargs: Alternative way to specify indexers as keyword arguments + :param indexers_kwargs: Alternative way to specify indexers as keyword arguments :type indexers_kwargs: typing.Any :return: New AxisArray with selected data :rtype: T @@ -368,14 +374,14 @@ def sel( ) -> T: """ Select data using label-based indexing along specified dimensions. - + This method allows selection using real-world coordinate values rather than integer indices. Currently supports only slice objects and LinearAxis. - + :param indexers: Dictionary of {dimension_name: slice_indexer} pairs :type indexers: typing.Any | None :param indexers_kwargs: Alternative way to specify indexers as keyword arguments - :type indexers_kwargs: typing.Any + :type indexers_kwargs: typing.Any :return: New AxisArray with selected data :rtype: T :raises ValueError: If indexer is not a slice or axis is not a LinearAxis @@ -400,7 +406,7 @@ def sel( def shape(self) -> tuple[int, ...]: """ Shape of data. - + :return: Tuple representing the shape of the underlying data array :rtype: tuple[int, ...] """ @@ -409,10 +415,10 @@ def shape(self) -> tuple[int, ...]: def to_xr_dataarray(self) -> "DataArray": """ Convert this AxisArray to an xarray DataArray. - + This method creates an xarray DataArray with equivalent data, coordinates, dimensions, and attributes. Useful for interoperability with the xarray ecosystem. - + :return: xarray DataArray representation of this AxisArray :rtype: xarray.DataArray """ @@ -430,7 +436,7 @@ def to_xr_dataarray(self) -> "DataArray": def ax(self, dim: str | int) -> AxisInfo: """ Get AxisInfo for a specified dimension. - + :param dim: Dimension name or index :type dim: str | int :return: AxisInfo containing axis, index, and size information @@ -445,7 +451,7 @@ def ax(self, dim: str | int) -> AxisInfo: def get_axis(self, dim: str | int) -> AxisBase: """ Get the axis coordinate system for a specified dimension. - + :param dim: Dimension name or index :type dim: str | int :return: The axis coordinate system (defaults to LinearAxis if not specified) @@ -461,7 +467,7 @@ def get_axis(self, dim: str | int) -> AxisBase: def get_axis_name(self, dim: int) -> str: """ Get the dimension name for a given axis index. - + :param dim: The axis index :type dim: int :return: The dimension name @@ -472,7 +478,7 @@ def get_axis_name(self, dim: int) -> str: def get_axis_idx(self, dim: str) -> int: """ Get the axis index for a given dimension name. - + :param dim: The dimension name :type dim: str :return: The axis index @@ -487,7 +493,7 @@ def get_axis_idx(self, dim: str) -> int: def axis_idx(self, dim: str | int) -> int: """ Get the axis index for a given dimension name or pass through if already an int. - + :param dim: Dimension name or index :type dim: str | int :return: The axis index @@ -502,7 +508,7 @@ def _axis_idx(self, dim: str | int) -> int: def as2d(self, dim: str | int) -> npt.NDArray: """ Get a 2D view of the data with the specified dimension as the first axis. - + :param dim: Dimension name or index to move to first axis :type dim: str | int :return: 2D array view with shape (dim_size, remaining_elements) @@ -510,15 +516,13 @@ def as2d(self, dim: str | int) -> npt.NDArray: """ return as2d(self.data, self.axis_idx(dim), xp=get_namespace(self.data)) - def iter_over_axis( - self: T, axis: str | int - ) -> Generator[T, None, None]: + def iter_over_axis(self: T, axis: str | int) -> Generator[T, None, None]: """ Iterate over slices along the specified axis. - + Yields AxisArray objects for each slice along the given axis, with that dimension removed from the resulting arrays. - + :param axis: Dimension name or index to iterate over :type axis: str | int :yields: AxisArray objects for each slice along the axis @@ -540,16 +544,14 @@ def iter_over_axis( yield it_aa @contextmanager - def view2d( - self, dim: str | int - ) -> Generator[npt.NDArray, None, None]: + def view2d(self, dim: str | int) -> Generator[npt.NDArray, None, None]: """ Context manager providing a 2D view of the data. - + Yields a 2D array view with the specified dimension as the first axis. Changes to the yielded array may be reflected in the original data. - - :param dim: Dimension name or index to move to first axis + + :param dim: Dimension name or index to move to first axis :type dim: str | int :yields: 2D array view with shape (dim_size, remaining_elements) :rtype: collections.abc.Generator[npt.NDArray, None, None] @@ -560,8 +562,8 @@ def view2d( def shape2d(self, dim: str | int) -> tuple[int, int]: """ Get the 2D shape when viewing data with specified dimension first. - - :param dim: Dimension name or index + + :param dim: Dimension name or index :type dim: str | int :return: Tuple of (dim_size, remaining_elements) :rtype: tuple[int, int] @@ -577,7 +579,7 @@ def concatenate( ) -> T: """ Concatenate multiple AxisArray objects along a specified dimension. - + :param aas: Variable number of AxisArray objects to concatenate :type aas: T :param dim: Dimension name along which to concatenate @@ -637,7 +639,7 @@ def transpose( ) -> T: """ Transpose (reorder) the dimensions of an AxisArray. - + :param aa: The AxisArray to transpose :type aa: T :param dims: New dimension order (names or indices). If None, reverses all dimensions @@ -653,9 +655,7 @@ def transpose( return replace(aa, data=new_data, dims=new_dims, axes=aa.axes) -def slice_along_axis( - in_arr: npt.NDArray, sl: slice | int, axis: int -) -> npt.NDArray: +def slice_along_axis(in_arr: npt.NDArray, sl: slice | int, axis: int) -> npt.NDArray: """ Slice the input array along a specified axis using the given slice object or integer index. Integer arguments to `sl` will cause the sliced dimension to be dropped. @@ -689,7 +689,7 @@ def sliding_win_oneaxis( ) -> npt.NDArray: """ Generates a view of an array using a sliding window of specified length along a specified axis of the input array. - This is a slightly optimized version of nps.sliding_window_view with a few important differences: + This is a slightly optimized version of numpy.lib.stride_tricks.sliding_window_view with a few important differences: - This only accepts a single nwin and a single axis, thus we can skip some checks. - The new `win` axis precedes immediately the original target axis, unlike sliding_window_view where the @@ -757,7 +757,7 @@ def _as2d( ) -> tuple[npt.NDArray, tuple[int, ...]]: """ Internal helper function to reshape array to 2D with specified axis first. - + :param in_arr: Input array to be reshaped :type in_arr: npt.NDArray :param axis: Axis to move to first position (default: 0) @@ -781,7 +781,7 @@ def _as2d( def as2d(in_arr: npt.NDArray, axis: int = 0, *, xp) -> npt.NDArray: """ Reshape array to 2D with specified axis first. - + :param in_arr: Input array :type in_arr: npt.NDArray :param axis: Axis to move to first position (default: 0) @@ -795,9 +795,7 @@ def as2d(in_arr: npt.NDArray, axis: int = 0, *, xp) -> npt.NDArray: @contextmanager -def view2d( - in_arr: npt.NDArray, axis: int = 0 -) -> Generator[npt.NDArray, None, None]: +def view2d(in_arr: npt.NDArray, axis: int = 0) -> Generator[npt.NDArray, None, None]: """ Context manager providing 2D view of the array, no matter what input dimensionality is. Yields a view of underlying data when possible, changes to data in yielded @@ -806,8 +804,8 @@ def view2d( NOTE: In practice, I'm not sure this is very useful because it requires modifying the numpy array data in-place, which limits its application to zero-copy messages - NOTE: The context manager allows the use of `with` when calling `view2d`. - + NOTE: The context manager allows the use of `with` when calling `view2d`. + :param in_arr: Input array to be viewed as 2D :type in_arr: npt.NDArray :param axis: Dimension index to move to first axis (Default = 0) @@ -829,7 +827,7 @@ def view2d( def shape2d(arr: npt.NDArray, axis: int = 0) -> tuple[int, int]: """ Calculate the 2D shape when viewing array with specified axis first. - + :param arr: Input array :type arr: npt.NDArray :param axis: Axis to move to first position (default: 0) diff --git a/src/ezmsg/util/messages/chunker.py b/src/ezmsg/util/messages/chunker.py index 402fd2bc..a7bacbab 100644 --- a/src/ezmsg/util/messages/chunker.py +++ b/src/ezmsg/util/messages/chunker.py @@ -1,7 +1,6 @@ import asyncio from collections.abc import Generator, AsyncGenerator import traceback -import typing import numpy as np import numpy.typing as npt @@ -21,7 +20,7 @@ def array_chunker( """ Create a generator that yields AxisArrays containing chunks of an array along a specified axis. - + The generator should be useful for quick offline analyses, tests, or examples. This generator probably is not useful for online streaming applications. @@ -66,7 +65,7 @@ def array_chunker( class ArrayChunkerSettings(ez.Settings): """ Settings for ArrayChunker unit. - + Configuration for chunking array data along a specified axis with timing information. :param data: An array_like object to iterate over, chunk-by-chunk. @@ -80,6 +79,7 @@ class ArrayChunkerSettings(ez.Settings): :param tzero: The time offset of the first chunk. Will only be used to make the time axis. :type tzero: float """ + data: npt.ArrayLike chunk_len: int axis: int = 0 @@ -90,10 +90,11 @@ class ArrayChunkerSettings(ez.Settings): class ArrayChunker(ez.Unit): """ Unit for chunking array data along a specified axis. - + Converts array data into sequential chunks along a specified axis, with proper timing axis information for streaming applications. """ + SETTINGS = ArrayChunkerSettings STATE = GenState @@ -103,7 +104,7 @@ class ArrayChunker(ez.Unit): async def initialize(self) -> None: """ Initialize the ArrayChunker unit. - + Sets up the generator for chunking operations based on current settings. """ self.construct_generator() @@ -111,7 +112,7 @@ async def initialize(self) -> None: def construct_generator(self): """ Construct the chunking generator with current settings. - + Creates a new array_chunker generator instance using the unit's settings. """ self.STATE.gen = array_chunker( @@ -126,7 +127,7 @@ def construct_generator(self): async def on_settings(self, msg: ez.Settings) -> None: """ Handle incoming settings updates. - + :param msg: New settings to apply. :type msg: ez.Settings """ @@ -137,10 +138,10 @@ async def on_settings(self, msg: ez.Settings) -> None: async def send_chunk(self) -> AsyncGenerator: """ Publisher method that yields data chunks. - + Continuously yields chunks from the generator until exhausted, with proper exception handling for completion cases. - + :return: Async generator yielding AxisArray chunks. :rtype: collections.abc.AsyncGenerator """ diff --git a/src/ezmsg/util/messages/key.py b/src/ezmsg/util/messages/key.py index 23e982fe..1850f053 100644 --- a/src/ezmsg/util/messages/key.py +++ b/src/ezmsg/util/messages/key.py @@ -1,6 +1,5 @@ from collections.abc import AsyncGenerator, Generator import traceback -import typing import numpy as np @@ -32,22 +31,24 @@ def set_key(key: str = "") -> Generator[AxisArray, AxisArray, None]: class KeySettings(ez.Settings): """ Settings for key manipulation units. - + Configuration for setting or filtering AxisArray keys. - + :param key: The string to set as the key. :type key: str """ + key: str = "" class SetKey(ez.Unit): """ Unit for setting the key of incoming AxisArray messages. - + Modifies the key field of AxisArray messages while preserving all other data. Uses zero-copy operations for efficient processing. """ + STATE = GenState SETTINGS = KeySettings @@ -58,7 +59,7 @@ class SetKey(ez.Unit): async def initialize(self) -> None: """ Initialize the SetKey unit. - + Sets up the generator for key modification operations. """ self.construct_generator() @@ -66,7 +67,7 @@ async def initialize(self) -> None: def construct_generator(self): """ Construct the key-setting generator with current settings. - + Creates a new set_key generator instance using the unit's key setting. """ self.STATE.gen = set_key(key=self.SETTINGS.key) @@ -75,7 +76,7 @@ def construct_generator(self): async def on_settings(self, msg: ez.Settings) -> None: """ Handle incoming settings updates. - + :param msg: New settings to apply. :type msg: ez.Settings """ @@ -87,10 +88,10 @@ async def on_settings(self, msg: ez.Settings) -> None: async def on_message(self, message: AxisArray) -> AsyncGenerator: """ Process incoming AxisArray messages and set their keys. - + Uses zero-copy operations to efficiently modify the key field while preserving all other data. - + :param message: Input AxisArray to modify. :type message: AxisArray :return: Async generator yielding AxisArray with modified key. @@ -109,10 +110,10 @@ async def on_message(self, message: AxisArray) -> AsyncGenerator: class FilterOnKey(ez.Unit): """ Filter an AxisArray based on its key. - + Only passes through AxisArray messages whose key matches the configured key setting. Uses zero-copy operations for efficient filtering. - + Note: There is no associated generator method for this Unit because messages that fail the filter would still be yielded (as None), which complicates downstream processing. For contexts where filtering on key is desired but the ezmsg framework is not used, use normal Python functional programming. @@ -129,10 +130,10 @@ class FilterOnKey(ez.Unit): async def on_message(self, message: AxisArray) -> AsyncGenerator: """ Filter incoming AxisArray messages based on their key. - + Only yields messages whose key matches the configured filter key. Uses minimal 'touch' to prevent unnecessary deep copying by the framework. - + :param message: Input AxisArray to filter. :type message: AxisArray :return: Async generator yielding filtered AxisArray messages. diff --git a/src/ezmsg/util/messages/modify.py b/src/ezmsg/util/messages/modify.py index 19f784b3..49a9bf9c 100644 --- a/src/ezmsg/util/messages/modify.py +++ b/src/ezmsg/util/messages/modify.py @@ -1,6 +1,5 @@ from collections.abc import AsyncGenerator, Generator import traceback -import typing import numpy as np @@ -60,23 +59,25 @@ def modify_axis( class ModifyAxisSettings(ez.Settings): """ Settings for ModifyAxis unit. - + Configuration for modifying axis names and dimensions of AxisArray messages. :param name_map: A dictionary where the keys are the names of the old dims and the values are the new names. Use None as a value to drop the dimension. If the dropped dimension is not len==1 then an error is raised. :type name_map: dict[str, str | None] | None """ + name_map: dict[str, str | None] | None = None class ModifyAxis(ez.Unit): """ Unit for modifying axis names and dimensions of AxisArray messages. - + Renames dimensions and axes according to a name mapping, with support for dropping dimensions. Uses zero-copy operations for efficient processing. """ + STATE = GenState SETTINGS = ModifyAxisSettings @@ -87,7 +88,7 @@ class ModifyAxis(ez.Unit): async def initialize(self) -> None: """ Initialize the ModifyAxis unit. - + Sets up the generator for axis modification operations. """ self.construct_generator() @@ -95,7 +96,7 @@ async def initialize(self) -> None: def construct_generator(self): """ Construct the axis-modifying generator with current settings. - + Creates a new modify_axis generator instance using the unit's name mapping. """ self.STATE.gen = modify_axis(name_map=self.SETTINGS.name_map) @@ -104,7 +105,7 @@ def construct_generator(self): async def on_settings(self, msg: ez.Settings) -> None: """ Handle incoming settings updates. - + :param msg: New settings to apply. :type msg: ez.Settings """ @@ -116,10 +117,10 @@ async def on_settings(self, msg: ez.Settings) -> None: async def on_message(self, message: AxisArray) -> AsyncGenerator: """ Process incoming AxisArray messages and modify their axes. - + Uses zero-copy operations to efficiently modify axis names and dimensions while preserving data integrity. - + :param message: Input AxisArray to modify. :type message: AxisArray :return: Async generator yielding AxisArray with modified axes. diff --git a/src/ezmsg/util/messages/util.py b/src/ezmsg/util/messages/util.py index 72094bde..a632f69e 100644 --- a/src/ezmsg/util/messages/util.py +++ b/src/ezmsg/util/messages/util.py @@ -9,7 +9,7 @@ def fast_replace(arr: typing.Generic[T], **kwargs) -> T: """ Fast replacement of dataclass fields with reduced safety. - + Unlike dataclasses.replace, this function does not check for type compatibility, nor does it check that the passed in fields are valid fields for the dataclass and not flagged as init=False. @@ -18,7 +18,7 @@ def fast_replace(arr: typing.Generic[T], **kwargs) -> T: To force ezmsg to use the legacy replace, set the environment variable: EZMSG_DISABLE_FAST_REPLACE Unset the variable to use this replace function. - + :param arr: The dataclass instance to create a modified copy of. :type arr: typing.Generic[T] :param kwargs: Field values to update in the new instance. diff --git a/src/ezmsg/util/perf/analysis.py b/src/ezmsg/util/perf/analysis.py index 8c0a5cfe..590e681f 100644 --- a/src/ezmsg/util/perf/analysis.py +++ b/src/ezmsg/util/perf/analysis.py @@ -1,5 +1,4 @@ import json -import typing import dataclasses import argparse import html @@ -12,7 +11,7 @@ from .envinfo import TestEnvironmentInfo, format_env_diff from .run import get_datestamp from .impl import ( - TestParameters, + TestParameters, Metrics, TestLogEntry, ) @@ -21,15 +20,15 @@ try: import xarray as xr - import pandas as pd # xarray depends on pandas + import pandas as pd # xarray depends on pandas except ImportError: - ez.logger.error('ezmsg perf analysis requires xarray') + ez.logger.error("ezmsg perf analysis requires xarray") raise try: import numpy as np except ImportError: - ez.logger.error('ezmsg perf analysis requires numpy') + ez.logger.error("ezmsg perf analysis requires numpy") raise TEST_DESCRIPTION = """ @@ -58,15 +57,14 @@ def load_perf(perf: Path) -> xr.Dataset: - all_results: dict[TestParameters, dict[int, list[Metrics]]] = dict() run_idx = 0 - with open(perf, 'r') as perf_f: - info: TestEnvironmentInfo = json.loads(next(perf_f), cls = MessageDecoder) + with open(perf, "r") as perf_f: + info: TestEnvironmentInfo = json.loads(next(perf_f), cls=MessageDecoder) for line in perf_f: - obj = json.loads(line, cls = MessageDecoder) + obj = json.loads(line, cls=MessageDecoder) if isinstance(obj, TestEnvironmentInfo): run_idx += 1 elif isinstance(obj, TestLogEntry): @@ -81,38 +79,47 @@ def load_perf(perf: Path) -> xr.Dataset: comms_axis = list(sorted(set([p.comms for p in all_results.keys()]))) config_axis = list(sorted(set([p.config for p in all_results.keys()]))) - dims = ['n_clients', 'msg_size', 'comms', 'config'] + dims = ["n_clients", "msg_size", "comms", "config"] coords = { - 'n_clients': n_clients_axis, - 'msg_size': msg_size_axis, - 'comms': comms_axis, - 'config': config_axis + "n_clients": n_clients_axis, + "msg_size": msg_size_axis, + "comms": comms_axis, + "config": config_axis, } - + data_vars = {} for field in dataclasses.fields(Metrics): - m = np.zeros(( - len(n_clients_axis), - len(msg_size_axis), - len(comms_axis), - len(config_axis) - )) * np.nan + m = ( + np.zeros( + ( + len(n_clients_axis), + len(msg_size_axis), + len(comms_axis), + len(config_axis), + ) + ) + * np.nan + ) for p, a in all_results.items(): # tests are run multiple times; get the median of means m[ n_clients_axis.index(p.n_clients), msg_size_axis.index(p.msg_size), comms_axis.index(p.comms), - config_axis.index(p.config) - ] = np.median([np.mean([getattr(v, field.name) for v in r]) for r in a.values()]) - data_vars[field.name] = xr.DataArray(m, dims = dims, coords = coords) + config_axis.index(p.config), + ] = np.median( + [np.mean([getattr(v, field.name) for v in r]) for r in a.values()] + ) + data_vars[field.name] = xr.DataArray(m, dims=dims, coords=coords) - dataset = xr.Dataset(data_vars, attrs = dict(info = info)) + dataset = xr.Dataset(data_vars, attrs=dict(info=info)) return dataset + def _escape(s: str) -> str: return html.escape(str(s), quote=True) + def _env_block(title: str, body: str) -> str: return f"""
@@ -121,6 +128,7 @@ def _env_block(title: str, body: str) -> str:
""" + def _legend_block() -> str: return """
@@ -133,6 +141,7 @@ def _legend_block() -> str:
""" + def _base_css() -> str: # Minimal, print-friendly CSS + color scales for cells. return """ @@ -237,7 +246,10 @@ def _base_css() -> str: """ -def _color_for_comparison(value: float, metric: str, noise_band_pct: float = 10.0) -> str: + +def _color_for_comparison( + value: float, metric: str, noise_band_pct: float = 10.0 +) -> str: """ Returns inline CSS background for a comparison % value. value: e.g., 97.3, 104.8, etc. @@ -250,11 +262,11 @@ def _color_for_comparison(value: float, metric: str, noise_band_pct: float = 10. delta = value - 100.0 # Determine direction: + is good for sample/data; - is good for latency - if 'rate' in metric: + if "rate" in metric: # positive delta good, negative bad magnitude = abs(delta) sign_good = delta > 0 - elif 'latency' in metric: + elif "latency" in metric: # negative delta good (lower latency) magnitude = abs(delta) sign_good = delta < 0 @@ -275,6 +287,7 @@ def _color_for_comparison(value: float, metric: str, noise_band_pct: float = 10. alpha = 0.15 + 0.35 * scale # 0.15..0.50 return f"background-color: hsla({hue}, 70%, 45%, {alpha});" + def _format_number(x) -> str: if isinstance(x, (int,)) and not isinstance(x, bool): return f"{x:d}" @@ -287,13 +300,13 @@ def _format_number(x) -> str: def summary(perf_path: Path, baseline_path: Path | None, html: bool = False) -> None: - """ print perf test results and comparisons to the console """ + """print perf test results and comparisons to the console""" - output = '' + output = "" perf = load_perf(perf_path) - info: TestEnvironmentInfo = perf.attrs['info'] - output += str(info) + '\n\n' + info: TestEnvironmentInfo = perf.attrs["info"] + output += str(info) + "\n\n" relative = False env_diff = None @@ -302,42 +315,52 @@ def summary(perf_path: Path, baseline_path: Path | None, html: bool = False) -> output += "PERFORMANCE COMPARISON\n\n" baseline = load_perf(baseline_path) perf = (perf / baseline) * 100.0 - baseline_info: TestEnvironmentInfo = baseline.attrs['info'] + baseline_info: TestEnvironmentInfo = baseline.attrs["info"] env_diff = format_env_diff(info.diff(baseline_info)) - output += env_diff + '\n\n' + output += env_diff + "\n\n" - # These raw stats are still valuable to have, but are confusing + # These raw stats are still valuable to have, but are confusing # when making relative comparisons - perf = perf.drop_vars(['latency_total', 'num_msgs']) + perf = perf.drop_vars(["latency_total", "num_msgs"]) - perf = perf.stack(params = ['n_clients', 'msg_size']).dropna('params') + perf = perf.stack(params=["n_clients", "msg_size"]).dropna("params") df = perf.squeeze().to_dataframe() - df = df.drop('n_clients', axis = 1) - df = df.drop('msg_size', axis = 1) + df = df.drop("n_clients", axis=1) + df = df.drop("msg_size", axis=1) - for _, config_ds in perf.groupby('config'): - for _, comms_ds in config_ds.groupby('comms'): - output += str(comms_ds.squeeze().to_dataframe()) + '\n\n' - output += '\n' + for _, config_ds in perf.groupby("config"): + for _, comms_ds in config_ds.groupby("comms"): + output += str(comms_ds.squeeze().to_dataframe()) + "\n\n" + output += "\n" print(output) if html: # Ensure expected columns exist - expected_cols = {"sample_rate_mean", "sample_rate_median", "data_rate", "latency_mean", "latency_median"} + expected_cols = { + "sample_rate_mean", + "sample_rate_median", + "data_rate", + "latency_mean", + "latency_median", + } missing = expected_cols - set(df.columns) if missing: raise ValueError(f"Missing expected columns in dataset: {missing}") # We'll render a table per (config, comms) group. - groups = df.reset_index().sort_values( - by=["config", "comms", "n_clients", "msg_size"] - ).groupby(["config", "comms"], sort=False) + groups = ( + df.reset_index() + .sort_values(by=["config", "comms", "n_clients", "msg_size"]) + .groupby(["config", "comms"], sort=False) + ) # Build HTML parts: list[str] = [] parts.append("") - parts.append("") + parts.append( + "" + ) parts.append("ezmsg perf report") parts.append(_base_css()) parts.append("
") @@ -346,7 +369,7 @@ def summary(perf_path: Path, baseline_path: Path | None, html: bool = False) -> parts.append("

ezmsg Performance Report

") sub = str(perf_path) if baseline_path is not None: - sub += f' relative to {str(baseline_path)}' + sub += f" relative to {str(baseline_path)}" parts.append(f"
{_escape(sub)}
") parts.append("") @@ -366,11 +389,21 @@ def summary(perf_path: Path, baseline_path: Path | None, html: bool = False) -> # Render each group for (config, comms), g in groups: # Keep only expected columns in order - cols = ["n_clients", "msg_size", "sample_rate_mean", "sample_rate_median", "data_rate", "latency_mean", "latency_median"] + cols = [ + "n_clients", + "msg_size", + "sample_rate_mean", + "sample_rate_median", + "data_rate", + "latency_mean", + "latency_median", + ] g = g[cols].copy() # String format some columns (msg_size with separators) - g["msg_size"] = g["msg_size"].map(lambda x: f"{int(x):,}" if pd.notna(x) else x) + g["msg_size"] = g["msg_size"].map( + lambda x: f"{int(x):,}" if pd.notna(x) else x + ) # Build table manually so we can inject inline cell styles easily # (pandas Styler is great but produces bulky HTML; manual keeps it clean) @@ -378,26 +411,38 @@ def summary(perf_path: Path, baseline_path: Path | None, html: bool = False) -> n_clients - msg_size {'' if relative else '(b)'} - sample_rate_mean {'' if relative else '(msgs/s)'} - sample_rate_median {'' if relative else '(msgs/s)'} - data_rate {'' if relative else '(MB/s)'} - latency_mean {'' if relative else '(us)'} - latency_median {'' if relative else '(us)'} + msg_size {"" if relative else "(b)"} + sample_rate_mean {"" if relative else "(msgs/s)"} + sample_rate_median {"" if relative else "(msgs/s)"} + data_rate {"" if relative else "(MB/s)"} + latency_mean {"" if relative else "(us)"} + latency_median {"" if relative else "(us)"} """ body_rows: list[str] = [] for _, row in g.iterrows(): - sr, srm, dr, lt, lm = row["sample_rate_mean"], row["sample_rate_median"], row["data_rate"], row["latency_mean"], row["latency_median"] + sr, srm, dr, lt, lm = ( + row["sample_rate_mean"], + row["sample_rate_median"], + row["data_rate"], + row["latency_mean"], + row["latency_median"], + ) dr = dr if relative else dr / 2**20 lt = lt if relative else lt * 1e6 lm = lm if relative else lm * 1e6 - sr_style = _color_for_comparison(sr, "sample_rate_mean") if relative else "" - srm_style = _color_for_comparison(srm, "sample_rate_median") if relative else "" + sr_style = ( + _color_for_comparison(sr, "sample_rate_mean") if relative else "" + ) + srm_style = ( + _color_for_comparison(srm, "sample_rate_median") if relative else "" + ) dr_style = _color_for_comparison(dr, "data_rate") if relative else "" lt_style = _color_for_comparison(lt, "latency_mean") if relative else "" - lm_style = _color_for_comparison(lm, "latency_median") if relative else "" + lm_style = ( + _color_for_comparison(lm, "latency_median") if relative else "" + ) body_rows.append( "" @@ -422,13 +467,13 @@ def summary(perf_path: Path, baseline_path: Path | None, html: bool = False) -> parts.append("
") html_text = "".join(parts) - out_path = Path(f'report_{get_datestamp()}.html') + out_path = Path(f"report_{get_datestamp()}.html") out_path.write_text(html_text, encoding="utf-8") webbrowser.open(out_path.resolve().as_uri()) def setup_summary_cmdline(subparsers: argparse._SubParsersAction) -> None: - p_summary = subparsers.add_parser("summary", help = "summarize performance results") + p_summary = subparsers.add_parser("summary", help="summarize performance results") p_summary.add_argument( "perf", type=Path, @@ -443,12 +488,12 @@ def setup_summary_cmdline(subparsers: argparse._SubParsersAction) -> None: ) p_summary.add_argument( "--html", - action = 'store_true', - help = "generate an html output file and render results in browser", + action="store_true", + help="generate an html output file and render results in browser", ) - p_summary.set_defaults(_handler=lambda ns: summary( - perf_path = ns.perf, - baseline_path = ns.baseline, - html = ns.html - )) \ No newline at end of file + p_summary.set_defaults( + _handler=lambda ns: summary( + perf_path=ns.perf, baseline_path=ns.baseline, html=ns.html + ) + ) diff --git a/src/ezmsg/util/perf/command.py b/src/ezmsg/util/perf/command.py index ab919999..21fed7eb 100644 --- a/src/ezmsg/util/perf/command.py +++ b/src/ezmsg/util/perf/command.py @@ -3,15 +3,17 @@ from .analysis import setup_summary_cmdline from .run import setup_run_cmdline + def command() -> None: - parser = argparse.ArgumentParser(description = 'ezmsg perf test utility') + parser = argparse.ArgumentParser(description="ezmsg perf test utility") subparsers = parser.add_subparsers(dest="command", required=True) setup_run_cmdline(subparsers) setup_summary_cmdline(subparsers) - + ns = parser.parse_args() ns._handler(ns) + if __name__ == "__main__": command() diff --git a/src/ezmsg/util/perf/envinfo.py b/src/ezmsg/util/perf/envinfo.py index 7320d5fd..654454f0 100644 --- a/src/ezmsg/util/perf/envinfo.py +++ b/src/ezmsg/util/perf/envinfo.py @@ -22,33 +22,53 @@ def _git_commit() -> str: try: - return subprocess.check_output( - ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL - ).decode().strip() - except: + return ( + subprocess.check_output( + ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL + ) + .decode() + .strip() + ) + except (subprocess.CalledProcessError, FileNotFoundError): return "unknown" + def _git_branch() -> str: try: - return subprocess.check_output( - ["git", "rev-parse", "--abbrev-ref", "HEAD"], stderr=subprocess.DEVNULL - ).decode().strip() - except: + return ( + subprocess.check_output( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], stderr=subprocess.DEVNULL + ) + .decode() + .strip() + ) + except (subprocess.CalledProcessError, FileNotFoundError): return "unknown" + @dataclasses.dataclass class TestEnvironmentInfo: ezmsg_version: str = dataclasses.field(default_factory=lambda: ez.__version__) numpy_version: str = dataclasses.field(default_factory=lambda: np.__version__) - python_version: str = dataclasses.field(default_factory=lambda: sys.version.replace("\n", " ")) + python_version: str = dataclasses.field( + default_factory=lambda: sys.version.replace("\n", " ") + ) os: str = dataclasses.field(default_factory=lambda: platform.system()) os_version: str = dataclasses.field(default_factory=lambda: platform.version()) machine: str = dataclasses.field(default_factory=lambda: platform.machine()) processor: str = dataclasses.field(default_factory=lambda: platform.processor()) - cpu_count_logical: int | None = dataclasses.field(default_factory=lambda: psutil.cpu_count(logical=True)) - cpu_count_physical: int | None = dataclasses.field(default_factory=lambda: psutil.cpu_count(logical=False)) - memory_gb: float = dataclasses.field(default_factory=lambda:round(psutil.virtual_memory().total / (1024**3), 2)) - start_time: str = dataclasses.field(default_factory=lambda: datetime.datetime.now().isoformat(timespec="seconds")) + cpu_count_logical: int | None = dataclasses.field( + default_factory=lambda: psutil.cpu_count(logical=True) + ) + cpu_count_physical: int | None = dataclasses.field( + default_factory=lambda: psutil.cpu_count(logical=False) + ) + memory_gb: float = dataclasses.field( + default_factory=lambda: round(psutil.virtual_memory().total / (1024**3), 2) + ) + start_time: str = dataclasses.field( + default_factory=lambda: datetime.datetime.now().isoformat(timespec="seconds") + ) git_commit: str = dataclasses.field(default_factory=_git_commit) git_branch: str = dataclasses.field(default_factory=_git_branch) @@ -59,8 +79,10 @@ def __str__(self) -> str: for key, value in fields.items(): lines.append(f" {key.ljust(width)} : {value}") return "\n".join(lines) - - def diff(self, other: "TestEnvironmentInfo") -> dict[str, tuple[typing.Any, typing.Any]]: + + def diff( + self, other: "TestEnvironmentInfo" + ) -> dict[str, tuple[typing.Any, typing.Any]]: """Return a structured diff: {field: (self_value, other_value)} for changed fields.""" a = dataclasses.asdict(self) b = dataclasses.asdict(other) @@ -79,5 +101,6 @@ def format_env_diff(diffs: dict[str, tuple[typing.Any, typing.Any]]) -> str: lines.append(f" {k.ljust(width)} : {left} != {right}") return "\n".join(lines) + def diff_envs(a: TestEnvironmentInfo, b: TestEnvironmentInfo) -> str: - return format_env_diff(a.diff(b)) \ No newline at end of file + return format_env_diff(a.diff(b)) diff --git a/src/ezmsg/util/perf/impl.py b/src/ezmsg/util/perf/impl.py index fc46b81e..4ab6d7f2 100644 --- a/src/ezmsg/util/perf/impl.py +++ b/src/ezmsg/util/perf/impl.py @@ -25,19 +25,23 @@ def collect( process_components: typing.Collection[ez.Component] | None = None, **components_kwargs: ez.Component, ) -> ez.Collection: - """ collect a grouping of pre-configured components into a new "Collection" """ + """collect a grouping of pre-configured components into a new "Collection" """ from ezmsg.core.util import either_dict_or_kwargs - + components = either_dict_or_kwargs(components, components_kwargs, "collect") if components is None: raise ValueError("Must supply at least one component to run") - + out = ez.Collection() for name, comp in components.items(): comp._set_name(name) - out._components = components # FIXME: Component._components should be typehinted as a Mapping + out._components = ( + components # FIXME: Component._components should be typehinted as a Mapping + ) out.network = lambda: network - out.process_components = lambda: (out,) if process_components is None else process_components + out.process_components = ( + lambda: (out,) if process_components is None else process_components + ) return out @@ -67,9 +71,11 @@ class LoadTestSample: dynamic_data: np.ndarray key: str + class LoadTestSourceState(ez.State): counter: int = 0 + class LoadTestSource(ez.Unit): OUTPUT = ez.OutputStream(LoadTestSample) SETTINGS = LoadTestSettings @@ -84,7 +90,6 @@ async def publish(self) -> typing.AsyncGenerator: ez.logger.info(f"Load test publisher started. (PID: {os.getpid()})") start_time = time.time() for _ in range(self.SETTINGS.num_msgs): - current_time = time.time() if current_time - start_time >= self.SETTINGS.max_duration: break @@ -97,24 +102,23 @@ async def publish(self) -> typing.AsyncGenerator: dynamic_data=np.zeros( int(self.SETTINGS.dynamic_size // 4), dtype=np.float32 ), - key = self.name, + key=self.name, ), ) self.STATE.counter += 1 - + ez.logger.info("Exiting publish") raise ez.Complete - + async def shutdown(self) -> None: ez.logger.info(f"Samples sent: {self.STATE.counter}") - class LoadTestRelay(ez.Unit): INPUT = ez.InputStream(LoadTestSample) OUTPUT = ez.OutputStream(LoadTestSample) - @ez.subscriber(INPUT, zero_copy = True) + @ez.subscriber(INPUT, zero_copy=True) @ez.publisher(OUTPUT) async def on_msg(self, msg: LoadTestSample) -> typing.AsyncGenerator: yield self.OUTPUT, msg @@ -140,9 +144,7 @@ async def initialize(self) -> None: async def receive(self, sample: LoadTestSample) -> None: counter = self.STATE.counters.get(sample.key, -1) if sample.counter != counter + 1: - ez.logger.warning( - f"{sample.counter - counter - 1} samples skipped!" - ) + ez.logger.warning(f"{sample.counter - counter - 1} samples skipped!") self.STATE.received_data.append( (sample._timestamp, time.time(), sample.counter) ) @@ -150,7 +152,6 @@ async def receive(self, sample: LoadTestSample) -> None: class LoadTestSink(LoadTestReceiver): - INPUT = ez.InputStream(LoadTestSample) @ez.subscriber(INPUT, zero_copy=True) @@ -169,6 +170,7 @@ async def terminate(self) -> None: ### TEST CONFIGURATIONS + @dataclasses.dataclass class ConfigSettings: n_clients: int @@ -176,11 +178,13 @@ class ConfigSettings: source: LoadTestSource sink: LoadTestSink + Configuration = typing.Tuple[typing.Iterable[ez.Component], ez.NetworkDefinition] Configurator = typing.Callable[[ConfigSettings], Configuration] + def fanout(config: ConfigSettings) -> Configuration: - """ one pub to many subs """ + """one pub to many subs""" connections: ez.NetworkDefinition = [(config.source.OUTPUT, config.sink.INPUT)] subs = [LoadTestReceiver(config.settings) for _ in range(config.n_clients)] for sub in subs: @@ -188,19 +192,20 @@ def fanout(config: ConfigSettings) -> Configuration: return subs, connections + def fanin(config: ConfigSettings) -> Configuration: - """ many pubs to one sub """ + """many pubs to one sub""" connections: ez.NetworkDefinition = [(config.source.OUTPUT, config.sink.INPUT)] pubs = [LoadTestSource(config.settings) for _ in range(config.n_clients)] expected_num_msgs = config.sink.SETTINGS.num_msgs * len(pubs) - config.sink.SETTINGS = replace(config.sink.SETTINGS, num_msgs = expected_num_msgs) # type: ignore + config.sink.SETTINGS = replace(config.sink.SETTINGS, num_msgs=expected_num_msgs) # type: ignore for pub in pubs: connections.append((pub.OUTPUT, config.sink.INPUT)) return pubs, connections def relay(config: ConfigSettings) -> Configuration: - """ one pub to one sub through many relays """ + """one pub to one sub through many relays""" connections: ez.NetworkDefinition = [] relays = [LoadTestRelay(config.settings) for _ in range(config.n_clients)] @@ -209,49 +214,48 @@ def relay(config: ConfigSettings) -> Configuration: for from_relay, to_relay in zip(relays[:-1], relays[1:]): connections.append((from_relay.OUTPUT, to_relay.INPUT)) connections.append((relays[-1].OUTPUT, config.sink.INPUT)) - else: connections.append((config.source.OUTPUT, config.sink.INPUT)) + else: + connections.append((config.source.OUTPUT, config.sink.INPUT)) return relays, connections + CONFIGS: typing.Mapping[str, Configurator] = { - c.__name__: c for c in [ - fanin, - fanout, - relay - ] + c.__name__: c for c in [fanin, fanout, relay] } + class Communication(enum.StrEnum): LOCAL = "local" SHM = "shm" SHM_SPREAD = "shm_spread" TCP = "tcp" TCP_SPREAD = "tcp_spread" - + + def perform_test( n_clients: int, - max_duration: float, + max_duration: float, num_msgs: int, - msg_size: int, + msg_size: int, buffers: int, comms: Communication, config: Configurator, - graph_address: Address + graph_address: Address, ) -> Metrics: - settings = LoadTestSettings( - dynamic_size = int(msg_size), - num_msgs = num_msgs, - max_duration = max_duration, - buffers = buffers, - force_tcp = (comms in (Communication.TCP, Communication.TCP_SPREAD)), + dynamic_size=int(msg_size), + num_msgs=num_msgs, + max_duration=max_duration, + buffers=buffers, + force_tcp=(comms in (Communication.TCP, Communication.TCP_SPREAD)), ) source = LoadTestSource(settings) sink = LoadTestSink(settings) - + components: typing.Mapping[str, ez.Component] = dict( - SINK = sink, + SINK=sink, ) clients, connections = config(ConfigSettings(n_clients, settings, source, sink)) @@ -262,16 +266,15 @@ def perform_test( # Every component in the same process (this one) components["SOURCE"] = source for i, client in enumerate(clients): - components[f"CLIENT_{i+1}"] = client + components[f"CLIENT_{i + 1}"] = client else: - if comms in (Communication.SHM_SPREAD, Communication.TCP_SPREAD): # Every component in its own process. components["SOURCE"] = source process_components.append(source) for i, client in enumerate(clients): - components[f'CLIENT_{i+1}'] = client + components[f"CLIENT_{i + 1}"] = client process_components.append(client) else: @@ -279,24 +282,23 @@ def perform_test( collect_comps: typing.Mapping[str, ez.Component] = dict() collect_comps["SOURCE"] = source for i, client in enumerate(clients): - collect_comps[f"CLIENT_{i+1}"] = client - proc_collection = collect(components = collect_comps) + collect_comps[f"CLIENT_{i + 1}"] = client + proc_collection = collect(components=collect_comps) components["PROC"] = proc_collection process_components = [proc_collection] with stable_perf(): ez.run( - components = components, - connections = connections, - process_components = process_components, - graph_address = graph_address + components=components, + connections=connections, + process_components=process_components, + graph_address=graph_address, ) return calculate_metrics(sink) def calculate_metrics(sink: LoadTestSink) -> Metrics: - # Log some useful summary statistics min_timestamp = min(timestamp for timestamp, _, _ in sink.STATE.received_data) max_timestamp = max(timestamp for timestamp, _, _ in sink.STATE.received_data) @@ -335,13 +337,13 @@ def calculate_metrics(sink: LoadTestSink) -> Metrics: ) return Metrics( - num_msgs = num_samples, - sample_rate_mean = samplerate_mean, - sample_rate_median = samplerate_median, - latency_mean = latency_mean, - latency_median = latency_median, - latency_total = total_latency, - data_rate = data_rate + num_msgs=num_samples, + sample_rate_mean=samplerate_mean, + sample_rate_median=samplerate_median, + latency_mean=latency_mean, + latency_median=latency_median, + latency_total=total_latency, + data_rate=data_rate, ) @@ -359,4 +361,4 @@ class TestParameters: @dataclasses.dataclass class TestLogEntry: params: TestParameters - results: Metrics \ No newline at end of file + results: Metrics diff --git a/src/ezmsg/util/perf/run.py b/src/ezmsg/util/perf/run.py index abb49a16..3f794f21 100644 --- a/src/ezmsg/util/perf/run.py +++ b/src/ezmsg/util/perf/run.py @@ -1,7 +1,6 @@ import os import sys import json -import datetime import itertools import argparse import typing @@ -19,8 +18,8 @@ from .util import warmup from .impl import ( TestParameters, - TestLogEntry, - perform_test, + TestLogEntry, + perform_test, Communication, CONFIGS, ) @@ -29,36 +28,43 @@ DEFAULT_N_CLIENTS = [1, 16] DEFAULT_COMMS = [c for c in Communication] + # --- Output Suppression Context Manager --- @contextmanager def suppress_output(verbose: bool = False): """Context manager to redirect stdout and stderr to os.devnull""" - if verbose: + if verbose: yield else: # Open the null device for writing - with open(os.devnull, 'w') as fnull: + with open(os.devnull, "w") as fnull: # Redirect both stdout and stderr to the null device with redirect_stderr(fnull): with redirect_stdout(fnull): yield -CHECK_FOR_QUIT = lambda: False -if sys.platform.startswith('win'): +def _check_for_quit_default() -> bool: + return False + + +CHECK_FOR_QUIT = _check_for_quit_default + +if sys.platform.startswith("win"): import msvcrt + def _check_for_quit_win() -> bool: """ Checks for the 'q' key press in a non-blocking way. Returns True if 'q' is pressed (case-insensitive), False otherwise. """ # Windows: Use msvcrt for non-blocking keyboard hit detection - if msvcrt.kbhit(): # type: ignore + if msvcrt.kbhit(): # type: ignore # Read the key press (returns bytes) - key = msvcrt.getch() # type: ignore + key = msvcrt.getch() # type: ignore try: # Decode and check for 'q' - return key.decode().lower() == 'q' + return key.decode().lower() == "q" except UnicodeDecodeError: # Handle potential non-text key presses gracefully return False @@ -68,6 +74,7 @@ def _check_for_quit_win() -> bool: else: import select + def _check_for_quit() -> bool: """ Checks for the 'q' key press in a non-blocking way. @@ -77,19 +84,21 @@ def _check_for_quit() -> bool: # select.select(rlist, wlist, xlist, timeout) # timeout=0 makes it non-blocking if sys.stdin.isatty(): - i, o, e = select.select([sys.stdin], [], [], 0) # type: ignore + i, o, e = select.select([sys.stdin], [], [], 0) # type: ignore if i: # Read the buffered character key = sys.stdin.read(1) - return key.lower() == 'q' + return key.lower() == "q" return False - + CHECK_FOR_QUIT = _check_for_quit + def get_datestamp() -> str: return datetime.now().strftime("%Y%m%d_%H%M%S") -def perf_run( + +def perf_run( max_duration: float, num_msgs: int, num_buffers: int, @@ -102,36 +111,43 @@ def perf_run( grid: bool, warmup_dur: float, ) -> None: - if n_clients is None: n_clients = DEFAULT_N_CLIENTS if any(c < 0 for c in n_clients): - ez.logger.error('All tests must have >=0 clients') + ez.logger.error("All tests must have >=0 clients") return if msg_sizes is None: msg_sizes = DEFAULT_MSG_SIZES if any(s < 0 for s in msg_sizes): - ez.logger.error('All msg_sizes must be >=0 bytes') + ez.logger.error("All msg_sizes must be >=0 bytes") if not grid and len(list(n_clients)) != len(list(msg_sizes)): ez.logger.warning( - "Not performing a grid test of all combinations of n_clients and msg_sizes, but " + \ - f"{len(n_clients)=} which is not equal to {len(msg_sizes)=}. " + "Not performing a grid test of all combinations of n_clients and msg_sizes, but " + + f"{len(n_clients)=} which is not equal to {len(msg_sizes)=}. " ) try: - communications = DEFAULT_COMMS if comms is None else [Communication(c) for c in comms] + communications = ( + DEFAULT_COMMS if comms is None else [Communication(c) for c in comms] + ) except ValueError: - ez.logger.error(f"Invalid test communications requested. Valid communications: {', '.join([c.value for c in Communication])}") + ez.logger.error( + f"Invalid test communications requested. Valid communications: {', '.join([c.value for c in Communication])}" + ) return - + try: - configurators = list(CONFIGS.values()) if configs is None else [CONFIGS[c] for c in configs] + configurators = ( + list(CONFIGS.values()) if configs is None else [CONFIGS[c] for c in configs] + ) except ValueError: - ez.logger.error(f"Invalid test configuration requested. Valid configurations: {', '.join([c for c in CONFIGS])}") + ez.logger.error( + f"Invalid test configuration requested. Valid configurations: {', '.join([c for c in CONFIGS])}" + ) return - + subitr = itertools.product if grid else zip test_list = [ @@ -145,10 +161,18 @@ def perf_run( server = GraphServer() server.start() - ez.logger.info(f"About to run {len(test_list)} tests (repeated {repeats} times) of {max_duration} sec (max) each.") - ez.logger.info(f"During each test, source will attempt to send {num_msgs} messages to the sink.") - ez.logger.info(f"Please try to avoid running other taxing software while this perf test runs.") - ez.logger.info(f"NOTE: Tests swallow interrupt. After warmup, use 'q' then [enter] to quit tests early.") + ez.logger.info( + f"About to run {len(test_list)} tests (repeated {repeats} times) of {max_duration} sec (max) each." + ) + ez.logger.info( + f"During each test, source will attempt to send {num_msgs} messages to the sink." + ) + ez.logger.info( + "Please try to avoid running other taxing software while this perf test runs." + ) + ez.logger.info( + "NOTE: Tests swallow interrupt. After warmup, use 'q' then [enter] to quit tests early." + ) quitting = False @@ -158,61 +182,61 @@ def perf_run( ez.logger.info(f"Warming up for {warmup_dur} seconds...") warmup(warmup_dur) - with open(f'perf_{get_datestamp()}.txt', 'w') as out_f: + with open(f"perf_{get_datestamp()}.txt", "w") as out_f: for _ in range(repeats): - out_f.write(json.dumps(TestEnvironmentInfo(), cls = MessageEncoder) + "\n") + out_f.write( + json.dumps(TestEnvironmentInfo(), cls=MessageEncoder) + "\n" + ) for test_idx, (msg_size, clients, conf, comm) in enumerate(test_list): - if CHECK_FOR_QUIT(): ez.logger.info("Stopping tests early...") quitting = True break ez.logger.info( - f"TEST {test_idx + 1}/{len(test_list)}: " \ - f"{clients=}, {msg_size=}, conf={conf.__name__}, " \ + f"TEST {test_idx + 1}/{len(test_list)}: " + f"{clients=}, {msg_size=}, conf={conf.__name__}, " f"comm={comm.value}" ) output = TestLogEntry( - params = TestParameters( - msg_size = msg_size, - num_msgs = num_msgs, - n_clients = clients, - config = conf.__name__, - comms = comm.value, - max_duration = max_duration, - num_buffers = num_buffers + params=TestParameters( + msg_size=msg_size, + num_msgs=num_msgs, + n_clients=clients, + config=conf.__name__, + comms=comm.value, + max_duration=max_duration, + num_buffers=num_buffers, + ), + results=perform_test( + n_clients=clients, + max_duration=max_duration, + num_msgs=num_msgs, + msg_size=msg_size, + buffers=num_buffers, + comms=comm, + config=conf, + graph_address=server.address, ), - results = perform_test( - n_clients = clients, - max_duration = max_duration, - num_msgs = num_msgs, - msg_size = msg_size, - buffers = num_buffers, - comms = comm, - config = conf, - graph_address = server.address - ) ) - out_f.write(json.dumps(output, cls = MessageEncoder) + "\n") + out_f.write(json.dumps(output, cls=MessageEncoder) + "\n") if quitting: break finally: server.stop() - d = datetime(1,1,1) + timedelta(seconds = time.time() - start_time) - dur_str = ':'.join([str(n) for n in [d.day - 1, d.hour, d.minute, d.second] if n != 0]) + d = datetime(1, 1, 1) + timedelta(seconds=time.time() - start_time) + dur_str = ":".join( + [str(n) for n in [d.day - 1, d.hour, d.minute, d.second] if n != 0] + ) ez.logger.info(f"Tests concluded. Wallclock Runtime: {dur_str}s") - - def setup_run_cmdline(subparsers: argparse._SubParsersAction) -> None: - p_run = subparsers.add_parser("run", help="run performance test") p_run.add_argument( @@ -226,21 +250,21 @@ def setup_run_cmdline(subparsers: argparse._SubParsersAction) -> None: "--num-msgs", type=int, default=1000, - help = "number of messages to send per-test (default = 1000)" + help="number of messages to send per-test (default = 1000)", ) - # NOTE: We default num-buffers = 1 because this degenerate perf test scenario (blasting - # messages as fast as possible through the system) results in one of two scenerios: + # NOTE: We default num-buffers = 1 because this degenerate perf test scenario (blasting + # messages as fast as possible through the system) results in one of two scenerios: # 1. A (few) messages is/are enqueued and dequeued before another message is posted # 2. The buffer fills up before being FULLY emptied resulting in longer latency. # (once a channel enters this condition, it tends to stay in this condition) # # This _indeterminate_ behavior results in bimodal distributions of runtimes that make - # A/B performance comparisons difficult. The perf test is not representative of the vast - # majority of production ezmsg systems where publishing is generally rate-limited. + # A/B performance comparisons difficult. The perf test is not representative of the vast + # majority of production ezmsg systems where publishing is generally rate-limited. # - # A flow-control algorithm could stabilize perf-test results with num_buffers > 1, but is - # generally implemented by enforcing delays on the publish side which simply degrades + # A flow-control algorithm could stabilize perf-test results with num_buffers > 1, but is + # generally implemented by enforcing delays on the publish side which simply degrades # performance in the vast majority of ezmsg systems. - Griff p_run.add_argument( "--num-buffers", @@ -250,68 +274,72 @@ def setup_run_cmdline(subparsers: argparse._SubParsersAction) -> None: ) p_run.add_argument( - "--iters", "-i", - type = int, - default = 5, - help = "number of times to run each test (default = 5)" + "--iters", + "-i", + type=int, + default=5, + help="number of times to run each test (default = 5)", ) p_run.add_argument( - "--repeats", "-r", - type = int, - default = 10, - help = "number of times to repeat the perf (default = 10)" + "--repeats", + "-r", + type=int, + default=10, + help="number of times to repeat the perf (default = 10)", ) p_run.add_argument( "--msg-sizes", - type = int, - default = None, - nargs = "*", - help = f"message sizes in bytes (default = {DEFAULT_MSG_SIZES})" + type=int, + default=None, + nargs="*", + help=f"message sizes in bytes (default = {DEFAULT_MSG_SIZES})", ) p_run.add_argument( "--n-clients", - type = int, - default = None, - nargs = "*", - help = f"number of clients (default = {DEFAULT_N_CLIENTS})" + type=int, + default=None, + nargs="*", + help=f"number of clients (default = {DEFAULT_N_CLIENTS})", ) p_run.add_argument( "--comms", - type = str, - default = None, - nargs = "*", - help = f"communication strategies to test (default = {[c.value for c in DEFAULT_COMMS]})" + type=str, + default=None, + nargs="*", + help=f"communication strategies to test (default = {[c.value for c in DEFAULT_COMMS]})", ) p_run.add_argument( "--configs", - type = str, - default = None, - nargs = "*", - help = f"configurations to test (default = {[c for c in CONFIGS]})" + type=str, + default=None, + nargs="*", + help=f"configurations to test (default = {[c for c in CONFIGS]})", ) p_run.add_argument( "--warmup", - type = float, - default = 60.0, - help = "warmup CPU with busy task for some number of seconds (default = 60.0)" + type=float, + default=60.0, + help="warmup CPU with busy task for some number of seconds (default = 60.0)", ) - p_run.set_defaults(_handler=lambda ns: perf_run( - max_duration = ns.max_duration, - num_msgs = ns.num_msgs, - num_buffers = ns.num_buffers, - iters = ns.iters, - repeats = ns.repeats, - msg_sizes = ns.msg_sizes, - n_clients = ns.n_clients, - comms = ns.comms, - configs = ns.configs, - grid = True, - warmup_dur = ns.warmup, - )) \ No newline at end of file + p_run.set_defaults( + _handler=lambda ns: perf_run( + max_duration=ns.max_duration, + num_msgs=ns.num_msgs, + num_buffers=ns.num_buffers, + iters=ns.iters, + repeats=ns.repeats, + msg_sizes=ns.msg_sizes, + n_clients=ns.n_clients, + comms=ns.comms, + configs=ns.configs, + grid=True, + warmup_dur=ns.warmup, + ) + ) diff --git a/src/ezmsg/util/perf/util.py b/src/ezmsg/util/perf/util.py index c17e7fe5..eb79562f 100644 --- a/src/ezmsg/util/perf/util.py +++ b/src/ezmsg/util/perf/util.py @@ -5,7 +5,6 @@ import statistics as stats import contextlib import subprocess -import platform from dataclasses import dataclass from typing import Iterable @@ -20,6 +19,7 @@ # ---------- Utilities ---------- + def _set_env_threads(single_thread: bool = True): """ Normalize math/threading libs so they don't spawn surprise worker threads. @@ -33,8 +33,10 @@ def _set_env_threads(single_thread: bool = True): # Keep PYTHONHASHSEED stable for deterministic dict/set iteration costs os.environ.setdefault("PYTHONHASHSEED", "0") + # ---------- Priority & Affinity ---------- + @contextlib.contextmanager def _process_priority(): """ @@ -48,13 +50,18 @@ def _process_priority(): orig_nice = None if _IS_WIN: try: - import ctypes, ctypes.wintypes as wt + import ctypes + kernel32 = ctypes.WinDLL("kernel32", use_last_error=True) ABOVE_NORMAL_PRIORITY_CLASS = 0x00008000 HIGH_PRIORITY_CLASS = 0x00000080 # Try High, fall back to Above Normal - if not kernel32.SetPriorityClass(kernel32.GetCurrentProcess(), HIGH_PRIORITY_CLASS): - kernel32.SetPriorityClass(kernel32.GetCurrentProcess(), ABOVE_NORMAL_PRIORITY_CLASS) + if not kernel32.SetPriorityClass( + kernel32.GetCurrentProcess(), HIGH_PRIORITY_CLASS + ): + kernel32.SetPriorityClass( + kernel32.GetCurrentProcess(), ABOVE_NORMAL_PRIORITY_CLASS + ) except Exception: pass else: @@ -79,6 +86,7 @@ def _process_priority(): except Exception: pass + @contextlib.contextmanager def _cpu_affinity(prefer_isolation: bool = True): """ @@ -98,7 +106,7 @@ def _cpu_affinity(prefer_isolation: bool = True): if prefer_isolation and len(cpus) > 2: # Pick two middle CPUs to avoid 0 which often handles interrupts mid = len(cpus) // 2 - cpus = [cpus[mid-1], cpus[mid]] + cpus = [cpus[mid - 1], cpus[mid]] p.cpu_affinity(cpus) yield finally: @@ -108,8 +116,10 @@ def _cpu_affinity(prefer_isolation: bool = True): except Exception: pass + # ---------- Platform-specific helpers ---------- + @contextlib.contextmanager def _mac_caffeinate(): """ @@ -132,6 +142,7 @@ def _mac_caffeinate(): except Exception: pass + @contextlib.contextmanager def _win_timer_resolution(ms: int = 1): """ @@ -141,6 +152,7 @@ def _win_timer_resolution(ms: int = 1): yield return import ctypes + winmm = ctypes.WinDLL("winmm") timeBeginPeriod = winmm.timeBeginPeriod timeEndPeriod = winmm.timeEndPeriod @@ -156,8 +168,10 @@ def _win_timer_resolution(ms: int = 1): except Exception: pass + # ---------- Warm-up & GC ---------- + def warmup(seconds: float = 60.0, fn=None, *args, **kwargs): """ Optional warm-up to reach steady clocks/caches. @@ -178,6 +192,7 @@ def warmup(seconds: float = 60.0, fn=None, *args, **kwargs): while time.perf_counter() < target: fn(*args, **kwargs) + @contextlib.contextmanager def gc_pause(): """ @@ -192,8 +207,10 @@ def gc_pause(): gc.enable() gc.collect() + # ---------- Robust statistics ---------- + def median_of_means(samples: Iterable[float], k: int = 5) -> float: """ Robust estimate: split samples into k buckets (round-robin), average each, take median of bucket means. @@ -205,22 +222,25 @@ def median_of_means(samples: Iterable[float], k: int = 5) -> float: buckets = [[] for _ in range(k)] for i, v in enumerate(samples): buckets[i % k].append(v) - means = [sum(b)/len(b) for b in buckets if b] + means = [sum(b) / len(b) for b in buckets if b] means.sort() - return means[len(means)//2] + return means[len(means) // 2] + def coef_var(samples: Iterable[float]) -> float: vals = list(samples) if len(vals) < 2: return 0.0 - m = sum(vals)/len(vals) + m = sum(vals) / len(vals) if m == 0: return 0.0 sd = stats.pstdev(vals) return sd / m + # ---------- Public context manager ---------- + @dataclass class PerfOptions: single_thread_math: bool = True @@ -230,6 +250,7 @@ class PerfOptions: tweak_timer_windows: bool = True keep_mac_awake: bool = True + @contextlib.contextmanager def stable_perf(opts: PerfOptions = PerfOptions()): """ diff --git a/src/ezmsg/util/rate.py b/src/ezmsg/util/rate.py index 91452eba..e8e51afb 100644 --- a/src/ezmsg/util/rate.py +++ b/src/ezmsg/util/rate.py @@ -30,7 +30,7 @@ def __init__(self, hz: float): def _remaining(self, curr_time: float) -> float: """ Calculate the time remaining for rate to sleep. - + :param curr_time: Current time :type curr_time: float :return: Time remaining in seconds @@ -47,7 +47,7 @@ def _remaining(self, curr_time: float) -> float: def remaining(self) -> float: """ Return the time remaining for rate to sleep. - + :return: Time remaining in seconds :rtype: float """ @@ -57,7 +57,7 @@ def remaining(self) -> float: def _sleep_logic(self) -> float: """ Internal method to calculate sleep time and update timing state. - + :return: Time to sleep in seconds (non-negative) :rtype: float """ diff --git a/tests/ez_test_utils.py b/tests/ez_test_utils.py index e31f01b5..04537489 100644 --- a/tests/ez_test_utils.py +++ b/tests/ez_test_utils.py @@ -4,7 +4,6 @@ import os from pathlib import Path import tempfile -import typing import ezmsg.core as ez @@ -24,7 +23,9 @@ def get_test_fn(test_name: str | None = None, extension: str = "txt") -> Path: # full test suite in parallel or when other tests use the same test name. # Use NamedTemporaryFile with delete=False so callers can open/remove it. prefix = f"{test_name}-" if test_name else "test-" - tmp = tempfile.NamedTemporaryFile(prefix=prefix, suffix=f".{extension}", delete=False) + tmp = tempfile.NamedTemporaryFile( + prefix=prefix, suffix=f".{extension}", delete=False + ) tmp.close() return Path(tmp.name) diff --git a/tests/messages/test_axisarray.py b/tests/messages/test_axisarray.py index 6dfa79b0..e11bc474 100644 --- a/tests/messages/test_axisarray.py +++ b/tests/messages/test_axisarray.py @@ -1,3 +1,4 @@ +import importlib.util import pytest import numpy as np @@ -274,15 +275,8 @@ def test_sliding_win_oneaxis(nwin: int, axis: int, step: int): assert np.shares_memory(res, expected) -def xarray_available(): - try: - import xarray - - return True - except ImportError: - return False - except ValueError: - return False +def xarray_available() -> bool: + return importlib.util.find_spec("xarray") is not None @pytest.mark.skipif( diff --git a/tests/messages/test_key.py b/tests/messages/test_key.py index d5222680..02bb8da2 100644 --- a/tests/messages/test_key.py +++ b/tests/messages/test_key.py @@ -2,7 +2,6 @@ import copy import json import os -import typing import numpy as np diff --git a/tests/messages/test_modify.py b/tests/messages/test_modify.py index 277db8ad..91f3c708 100644 --- a/tests/messages/test_modify.py +++ b/tests/messages/test_modify.py @@ -1,5 +1,4 @@ import copy -import typing import numpy as np import pytest diff --git a/tests/test_graph.py b/tests/test_graph.py index e8927e70..5da45fb1 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -16,12 +16,8 @@ ("d", "e"), ] -simple_graph_2 = [ - ("w", "x"), - ("w", "y"), - ("x", "z"), - ("y", "z") -] +simple_graph_2 = [("w", "x"), ("w", "y"), ("x", "z"), ("y", "z")] + @pytest.mark.asyncio async def test_pub_first(): @@ -90,12 +86,11 @@ async def test_comms_simple(): await context.connect("a", "b") a_pub = await context.publisher("a") await a_pub.broadcast("HELLO") - print('DONE BROADCASTING') + print("DONE BROADCASTING") msg = await b_sub.recv() assert msg == "HELLO" - @pytest.mark.asyncio async def test_comms(): async with GraphContext() as context: @@ -235,8 +230,8 @@ async def run_async(self) -> None: if start is not None: delta = time.perf_counter() - start message_rate = int(self.n_msgs / delta) - print(f"{ message_rate } msgs/sec") - print(f"{ msg_size * message_rate / 1024 / 1024 / 1024 } GB/sec") + print(f"{message_rate} msgs/sec") + print(f"{msg_size * message_rate / 1024 / 1024 / 1024} GB/sec") await asyncio.get_running_loop().run_in_executor( None, self.stop_barrier.wait diff --git a/tests/test_new_messaging.py b/tests/test_new_messaging.py index 782187ec..683f1857 100644 --- a/tests/test_new_messaging.py +++ b/tests/test_new_messaging.py @@ -7,23 +7,26 @@ PORT = 12345 MAX_COUNT = 100 -TOPIC = '/TEST' +TOPIC = "/TEST" + async def handle_pub(pub: Publisher) -> None: - print('Publisher Task Launched') + print("Publisher Task Launched") count = 0 while True: - await pub.broadcast(f'{count=}') + await pub.broadcast(f"{count=}") await asyncio.sleep(0.1) count += 1 - if count >= MAX_COUNT: break + if count >= MAX_COUNT: + break + + print("Publisher Task Concluded") - print('Publisher Task Concluded') async def handle_sub(sub: Subscriber) -> None: - print('Subscriber Task Launched') + print("Subscriber Task Launched") rx_count = 0 while True: @@ -32,24 +35,24 @@ async def handle_sub(sub: Subscriber) -> None: print(msg) rx_count += 1 - if rx_count >= MAX_COUNT: break - - print('Subscriber Task Concluded') + if rx_count >= MAX_COUNT: + break + + print("Subscriber Task Concluded") -async def host(host: str = '127.0.0.1'): +async def host(host: str = "127.0.0.1"): # Manually create a GraphServer server = GraphServer() server.start((host, PORT)) - print(f'Created GraphServer @ {server.address}') + print(f"Created GraphServer @ {server.address}") # Create a graph_service that will interact with this GraphServer graph_service = GraphService((host, PORT)) await graph_service.ensure() try: - test_pub = await Publisher.create(TOPIC, (host, PORT), host=host) test_sub1 = await Subscriber.create(TOPIC, (host, PORT)) test_sub2 = await Subscriber.create(TOPIC, (host, PORT)) @@ -61,30 +64,32 @@ async def host(host: str = '127.0.0.1'): sub_task_2 = asyncio.Task(handle_sub(test_sub2)) await asyncio.wait([pub_task, sub_task_1, sub_task_2]) - + test_pub.close() test_sub1.close() test_sub2.close() - for future in asyncio.as_completed([ - test_pub.wait_closed(), - test_sub1.wait_closed(), - test_sub2.wait_closed(), - ]): + for future in asyncio.as_completed( + [ + test_pub.wait_closed(), + test_sub1.wait_closed(), + test_sub2.wait_closed(), + ] + ): await future finally: server.stop() - print('Done') + print("Done") -async def attach_client(host: str = '127.0.0.1'): +async def attach_client(host: str = "127.0.0.1"): # Attach to a running GraphServer graph_service = GraphService((host, PORT)) await graph_service.ensure() - print(f'Connected to GraphServer @ {graph_service.address}') + print(f"Connected to GraphServer @ {graph_service.address}") sub = await Subscriber.create(TOPIC, (host, PORT)) @@ -96,20 +101,20 @@ async def attach_client(host: str = '127.0.0.1'): except asyncio.CancelledError: pass - + finally: sub.close() await sub.wait_closed() - print(f'Detached') + print("Detached") -if __name__ == '__main__': +if __name__ == "__main__": from dataclasses import dataclass from argparse import ArgumentParser parser = ArgumentParser() - parser.add_argument('--attach', action = 'store_true', help = 'attach to running graph') - parser.add_argument('--host', default = "0.0.0.0", help = 'hostname for graphserver') + parser.add_argument("--attach", action="store_true", help="attach to running graph") + parser.add_argument("--host", default="0.0.0.0", help="hostname for graphserver") @dataclass class Args: @@ -119,6 +124,6 @@ class Args: args = Args(**vars(parser.parse_args())) if args.attach: - asyncio.run(attach_client(host = args.host)) + asyncio.run(attach_client(host=args.host)) else: - asyncio.run(host(host = args.host)) \ No newline at end of file + asyncio.run(host(host=args.host)) diff --git a/tests/test_reconstitute.py b/tests/test_reconstitute.py index f81c04bd..80e77784 100644 --- a/tests/test_reconstitute.py +++ b/tests/test_reconstitute.py @@ -13,31 +13,36 @@ import numpy as np -ADDR = ('127.0.0.1', 12345) +ADDR = ("127.0.0.1", 12345) ITERS = 10000 + def setup_graph_and_shm(msg_size): server = GraphServer() server.start(ADDR) graph_service = GraphService(ADDR) loop = asyncio.get_event_loop() loop.run_until_complete(graph_service.ensure()) - aa = AxisArray(data=np.random.normal(size=(msg_size,)), dims=['a']) + aa = AxisArray(data=np.random.normal(size=(msg_size,)), dims=["a"]) with MessageMarshal.serialize(0, aa) as (size, _, _): msg_size_bytes = size - shm = loop.run_until_complete(graph_service.create_shm(num_buffers=32, buf_size=msg_size_bytes + 100)) + shm = loop.run_until_complete( + graph_service.create_shm(num_buffers=32, buf_size=msg_size_bytes + 100) + ) with shm.buffer(0) as mem: MessageMarshal.to_mem(0, aa, mem) return server, shm, msg_size_bytes + def recon_with_deepcopy_func(shm): with shm.buffer(0, readonly=True) as mem: with MessageMarshal.obj_from_mem(mem) as obj: deepcopy(obj) + def recon_without_deepcopy_func(shm): with shm.buffer(0, readonly=True) as mem: - with MessageMarshal.obj_from_mem(mem) as obj: + with MessageMarshal.obj_from_mem(mem) as _obj: pass @@ -58,17 +63,19 @@ def main(): recon_without_deepcopy_func(shm) # Time a single reconstitution and a single deepcopy - recon_time = timeit.timeit( - stmt=lambda: recon_without_deepcopy_func(shm), - number=ITERS - ) / ITERS * 1e9 # ns - deepcopy_time = timeit.timeit( - stmt=lambda: recon_with_deepcopy_func(shm), - number=ITERS - ) / ITERS * 1e9 # ns - - print(f'Per reconstitution (ns): {recon_time}') - print(f'Per deepcopy (ns): {deepcopy_time}') + recon_time = ( + timeit.timeit(stmt=lambda: recon_without_deepcopy_func(shm), number=ITERS) + / ITERS + * 1e9 + ) # ns + deepcopy_time = ( + timeit.timeit(stmt=lambda: recon_with_deepcopy_func(shm), number=ITERS) + / ITERS + * 1e9 + ) # ns + + print(f"Per reconstitution (ns): {recon_time}") + print(f"Per deepcopy (ns): {deepcopy_time}") # For each fanout, compute total time for both strategies total_recon = [] @@ -80,9 +87,9 @@ def main(): total_recon.append(fanout * recon_time) results[msg_size_bytes] = { - 'fanouts': fanouts, - 'total_recon': total_recon, - 'total_deepcopy': total_deepcopy, + "fanouts": fanouts, + "total_recon": total_recon, + "total_deepcopy": total_deepcopy, } server.stop() @@ -91,29 +98,34 @@ def main(): if plt is not None: plt.figure(figsize=(12, 8)) for msg_size_bytes in msg_size_bytes_list: - fanouts = results[msg_size_bytes]['fanouts'] + fanouts = results[msg_size_bytes]["fanouts"] plt.plot( fanouts, - results[msg_size_bytes]['total_recon'], - label=f'Recon N times ({msg_size_bytes} bytes)', - marker='o', linestyle='--', alpha=0.7 + results[msg_size_bytes]["total_recon"], + label=f"Recon N times ({msg_size_bytes} bytes)", + marker="o", + linestyle="--", + alpha=0.7, ) plt.plot( fanouts, - results[msg_size_bytes]['total_deepcopy'], - label=f'Recon 1 + deepcopy N-1 ({msg_size_bytes} bytes)', - marker='x', linestyle='-', alpha=0.7 + results[msg_size_bytes]["total_deepcopy"], + label=f"Recon 1 + deepcopy N-1 ({msg_size_bytes} bytes)", + marker="x", + linestyle="-", + alpha=0.7, ) - plt.xscale('log', base=2) - plt.yscale('log') - plt.xlabel('Fanout (number of consumers)') - plt.ylabel('Total time (ns)') - plt.title('Fanout Tradeoff: Reconstitute N times vs Reconstitute+Deepcopy') - plt.legend(fontsize='small', ncol=2) - plt.grid(True, which='both', ls='--', alpha=0.5) + plt.xscale("log", base=2) + plt.yscale("log") + plt.xlabel("Fanout (number of consumers)") + plt.ylabel("Total time (ns)") + plt.title("Fanout Tradeoff: Reconstitute N times vs Reconstitute+Deepcopy") + plt.legend(fontsize="small", ncol=2) + plt.grid(True, which="both", ls="--", alpha=0.5) plt.tight_layout() - plt.savefig('fanout_tradeoff.png') - print('Plot saved as fanout_tradeoff.png') + plt.savefig("fanout_tradeoff.png") + print("Plot saved as fanout_tradeoff.png") + -if __name__ == '__main__': - main() \ No newline at end of file +if __name__ == "__main__": + main() From b13f7de8dcb9be7df3e1a7aa47533b8d0cb82673 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Thu, 20 Nov 2025 12:17:21 -0500 Subject: [PATCH 81/88] added unit tests for new functionality --- tests/test_channel.py | 93 +++++++++++++++++++++++++++++++ tests/test_channelmanager.py | 105 +++++++++++++++++++++++++++++++++++ tests/test_perf_configs.py | 99 +++++++++++++++++++++++++++++++++ 3 files changed, 297 insertions(+) create mode 100644 tests/test_channel.py create mode 100644 tests/test_channelmanager.py create mode 100644 tests/test_perf_configs.py diff --git a/tests/test_channel.py b/tests/test_channel.py new file mode 100644 index 00000000..9527bdcb --- /dev/null +++ b/tests/test_channel.py @@ -0,0 +1,93 @@ +import asyncio +from uuid import uuid4 + +import pytest + +from ezmsg.core.messagechannel import Channel +from ezmsg.core.messagecache import CacheMiss +from ezmsg.core.netprotocol import Command, uint64_to_bytes +from ezmsg.core.backpressure import Backpressure + + +class DummyWriter: + def __init__(self): + self.buffer: list[bytes] = [] + + def write(self, data: bytes) -> None: + self.buffer.append(data) + + +def _resolved_task(): + loop = asyncio.get_running_loop() + fut = loop.create_future() + fut.set_result(None) + return fut + + +@pytest.mark.asyncio +async def test_channel_acknowledges_remote_messages(): + channel = Channel(uuid4(), uuid4(), 2, None, None) + channel._pub_writer = DummyWriter() + channel._pub_task = _resolved_task() + channel._graph_task = _resolved_task() + + client_id = uuid4() + queue: asyncio.Queue = asyncio.Queue() + channel.register_client(client_id, queue) + + msg_id = 5 + payload = {"value": 42} + channel.cache.put_local(payload, msg_id) + channel._notify_clients(msg_id) + + assert queue.qsize() == 1 + queued_pub, queued_msg = queue.get_nowait() + assert queued_pub == channel.pub_id + assert queued_msg == msg_id + + with channel.get(msg_id, client_id) as obj: + assert obj == payload + + with pytest.raises(CacheMiss): + _ = channel.cache[msg_id] + + buf_idx = msg_id % channel.num_buffers + assert channel.backpressure.buffers[buf_idx].is_empty + + expected_ack = Command.RX_ACK.value + uint64_to_bytes(msg_id) + assert channel._pub_writer.buffer[-1] == expected_ack + + +@pytest.mark.asyncio +async def test_channel_releases_local_backpressure(monkeypatch): + channel = Channel(uuid4(), uuid4(), 2, None, None) + channel._pub_writer = DummyWriter() + channel._pub_task = _resolved_task() + channel._graph_task = _resolved_task() + + local_bp = Backpressure(channel.num_buffers) + channel.register_client(channel.pub_id, None, local_bp) + + client_id = uuid4() + queue: asyncio.Queue = asyncio.Queue() + channel.register_client(client_id, queue) + + msg_id = 3 + payload = "local" + channel.put_local(msg_id, payload) + + assert queue.qsize() == 1 + queue.get_nowait() + + with channel.get(msg_id, client_id) as obj: + assert obj == payload + + buf_idx = msg_id % channel.num_buffers + assert local_bp.buffers[buf_idx].is_empty + assert channel._pub_writer.buffer == [] + + +def test_channel_put_local_requires_local_backpressure(): + channel = Channel(uuid4(), uuid4(), 1, None, None) + with pytest.raises(ValueError): + channel.put_local(1, "no pub") diff --git a/tests/test_channelmanager.py b/tests/test_channelmanager.py new file mode 100644 index 00000000..0f78a6f8 --- /dev/null +++ b/tests/test_channelmanager.py @@ -0,0 +1,105 @@ +import asyncio +from dataclasses import dataclass +from uuid import uuid4 + +import pytest + +from ezmsg.core.channelmanager import ChannelManager +from ezmsg.core.backpressure import Backpressure +from ezmsg.core.netprotocol import Address +from ezmsg.core import channelmanager as channelmanager_module + + +@dataclass +class DummyChannel: + clients: dict + + def __init__(self): + self.clients = {} + self.closed = False + self.waited = False + self.local_bp: dict = {} + + def register_client(self, client_id, queue, local_backpressure): + self.clients[client_id] = queue + if local_backpressure is not None: + self.local_bp[client_id] = local_backpressure + + def unregister_client(self, client_id): + del self.clients[client_id] + + def close(self): + self.closed = True + + async def wait_closed(self): + self.waited = True + + +@pytest.mark.asyncio +async def test_channel_manager_reuses_existing_channel(monkeypatch): + dummy_channel = DummyChannel() + + async def fake_create(pub_id, address): + return dummy_channel + + monkeypatch.setattr(channelmanager_module.Channel, "create", fake_create) + + manager = ChannelManager() + pub_id = uuid4() + client_one = uuid4() + client_two = uuid4() + + queue_one: asyncio.Queue = asyncio.Queue() + queue_two: asyncio.Queue = asyncio.Queue() + + channel_a = await manager.register(pub_id, client_one, queue_one) + channel_b = await manager.register(pub_id, client_two, queue_two) + + assert channel_a is dummy_channel + assert channel_b is dummy_channel + assert dummy_channel.clients[client_one] is queue_one + assert dummy_channel.clients[client_two] is queue_two + + +@pytest.mark.asyncio +async def test_channel_manager_unregister_closes_channel(monkeypatch): + dummy_channel = DummyChannel() + + async def fake_create(pub_id, address): + dummy_channel.address = address + return dummy_channel + + monkeypatch.setattr(channelmanager_module.Channel, "create", fake_create) + + manager = ChannelManager() + pub_id = uuid4() + client_id = uuid4() + + queue: asyncio.Queue = asyncio.Queue() + await manager.register(pub_id, client_id, queue) + await manager.unregister(pub_id, client_id) + + default_address = Address.from_string(channelmanager_module.GRAPHSERVER_ADDR) + assert pub_id not in manager._registry[default_address] + assert dummy_channel.closed + assert dummy_channel.waited + + +@pytest.mark.asyncio +async def test_channel_manager_registers_local_publisher(monkeypatch): + dummy_channel = DummyChannel() + + async def fake_create(pub_id, address): + return dummy_channel + + monkeypatch.setattr(channelmanager_module.Channel, "create", fake_create) + + manager = ChannelManager() + pub_id = uuid4() + local_bp = Backpressure(1) + + channel = await manager.register_local_pub(pub_id, local_bp) + + assert channel is dummy_channel + assert dummy_channel.clients[pub_id] is None + assert dummy_channel.local_bp[pub_id] is local_bp diff --git a/tests/test_perf_configs.py b/tests/test_perf_configs.py new file mode 100644 index 00000000..82773f64 --- /dev/null +++ b/tests/test_perf_configs.py @@ -0,0 +1,99 @@ +import contextlib +import os +import tempfile +from pathlib import Path + +import pytest + +from ezmsg.core.graphserver import GraphServer +from ezmsg.util.perf.impl import Communication, CONFIGS, perform_test + + +PERF_MAX_DURATION = 0.5 +PERF_NUM_MSGS = 8 +PERF_MSG_SIZES = [64, 2**20] +PERF_NUM_BUFFERS = 2 +CLIENTS_PER_CONFIG = { + "fanout": 2, + "fanin": 2, + "relay": 2, +} + + +def _run_perf_case( + config_name: str, + comm: Communication, + msg_size: int, + server: GraphServer, +) -> None: + metrics = perform_test( + n_clients=CLIENTS_PER_CONFIG[config_name], + max_duration=PERF_MAX_DURATION, + num_msgs=PERF_NUM_MSGS, + msg_size=msg_size, + buffers=PERF_NUM_BUFFERS, + comms=comm, + config=CONFIGS[config_name], + graph_address=server.address, + ) + assert ( + metrics.num_msgs > 0 + ), f"Failed to exchange messages for {config_name}/{comm.value}/msg={msg_size}" + + +@contextlib.contextmanager +def _file_lock(path: Path): + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w+") as lock_file: + lock_file.write("0") + lock_file.flush() + if os.name == "nt": + import msvcrt + + msvcrt.locking(lock_file.fileno(), msvcrt.LK_LOCK, 1) + else: + import fcntl + + fcntl.flock(lock_file, fcntl.LOCK_EX) + try: + yield + finally: + if os.name == "nt": + msvcrt.locking(lock_file.fileno(), msvcrt.LK_UNLCK, 1) + else: + fcntl.flock(lock_file, fcntl.LOCK_UN) + + +@pytest.fixture(scope="module") +def perf_graph_server(): + """ + Spin up a dedicated graph server on an ephemeral port for the perf smoke tests. + The shared instance keeps startup/shutdown overhead down while isolating it + from the canonical graph server that other tests rely on. + """ + lock_path = Path(tempfile.gettempdir()) / "ezmsg_perf_smoke.lock" + with _file_lock(lock_path): + server = GraphServer() + server.start() + try: + yield server + finally: + server.stop() + + +@pytest.mark.parametrize("msg_size", PERF_MSG_SIZES, ids=lambda s: f"msg={s}") +@pytest.mark.parametrize("comm", list(Communication), ids=lambda c: f"comm={c.value}") +def test_fanout_perf(perf_graph_server, comm, msg_size): + _run_perf_case("fanout", comm, msg_size, perf_graph_server) + + +@pytest.mark.parametrize("msg_size", PERF_MSG_SIZES, ids=lambda s: f"msg={s}") +@pytest.mark.parametrize("comm", list(Communication), ids=lambda c: f"comm={c.value}") +def test_fanin_perf(perf_graph_server, comm, msg_size): + _run_perf_case("fanin", comm, msg_size, perf_graph_server) + + +@pytest.mark.parametrize("msg_size", PERF_MSG_SIZES, ids=lambda s: f"msg={s}") +@pytest.mark.parametrize("comm", list(Communication), ids=lambda c: f"comm={c.value}") +def test_relay_perf(perf_graph_server, comm, msg_size): + _run_perf_case("relay", comm, msg_size, perf_graph_server) From 8ff0657ca608b853fc7fdd488f729f7d1a41e4b4 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Thu, 20 Nov 2025 12:24:38 -0500 Subject: [PATCH 82/88] chore: ruff formatting --- src/ezmsg/util/messages/chunker.py | 2 +- tests/test_perf_configs.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ezmsg/util/messages/chunker.py b/src/ezmsg/util/messages/chunker.py index a7bacbab..20229ec2 100644 --- a/src/ezmsg/util/messages/chunker.py +++ b/src/ezmsg/util/messages/chunker.py @@ -38,7 +38,7 @@ def array_chunker( :rtype: collections.abc.Generator[AxisArray, None, None] """ - if not type(data) == np.ndarray: + if not type(data) == np.ndarray: # noqa: E721 # hot path optimization data = np.array(data) n_chunks = int(np.ceil(data.shape[axis] / chunk_len)) tvec = np.arange(n_chunks, dtype=float) * chunk_len / fs + tzero diff --git a/tests/test_perf_configs.py b/tests/test_perf_configs.py index 82773f64..fc7178aa 100644 --- a/tests/test_perf_configs.py +++ b/tests/test_perf_configs.py @@ -36,9 +36,9 @@ def _run_perf_case( config=CONFIGS[config_name], graph_address=server.address, ) - assert ( - metrics.num_msgs > 0 - ), f"Failed to exchange messages for {config_name}/{comm.value}/msg={msg_size}" + assert metrics.num_msgs > 0, ( + f"Failed to exchange messages for {config_name}/{comm.value}/msg={msg_size}" + ) @contextlib.contextmanager From 4d3f3d8ae3067027f30c99f424e298dd1d2666a4 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Thu, 20 Nov 2025 12:38:09 -0500 Subject: [PATCH 83/88] fix for python 3.10 --- src/ezmsg/util/perf/impl.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/ezmsg/util/perf/impl.py b/src/ezmsg/util/perf/impl.py index 4ab6d7f2..ef3b0067 100644 --- a/src/ezmsg/util/perf/impl.py +++ b/src/ezmsg/util/perf/impl.py @@ -18,6 +18,14 @@ from .util import stable_perf +try: + StrEnum = enum.StrEnum # type: ignore[attr-defined] +except AttributeError: + class StrEnum(str, enum.Enum): + """Fallback for Python < 3.11 where enum.StrEnum is unavailable.""" + + pass + def collect( components: typing.Optional[typing.Mapping[str, ez.Component]] = None, @@ -225,7 +233,7 @@ def relay(config: ConfigSettings) -> Configuration: } -class Communication(enum.StrEnum): +class Communication(StrEnum): LOCAL = "local" SHM = "shm" SHM_SPREAD = "shm_spread" From b0997f4aa23955a09d6bfd20ee493f7f4bdf9fd2 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Thu, 20 Nov 2025 14:58:13 -0500 Subject: [PATCH 84/88] fix: windows perf tests --- src/ezmsg/util/perf/impl.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/ezmsg/util/perf/impl.py b/src/ezmsg/util/perf/impl.py index ef3b0067..0e8460c8 100644 --- a/src/ezmsg/util/perf/impl.py +++ b/src/ezmsg/util/perf/impl.py @@ -4,6 +4,7 @@ import time import typing import enum +import sys import ezmsg.core as ez @@ -26,6 +27,9 @@ class StrEnum(str, enum.Enum): pass +TIME = time.monotonic +if sys.platform.startswith('win'): + TIME = time.perf_counter def collect( components: typing.Optional[typing.Mapping[str, ez.Component]] = None, @@ -96,16 +100,16 @@ async def initialize(self) -> None: @ez.publisher(OUTPUT) async def publish(self) -> typing.AsyncGenerator: ez.logger.info(f"Load test publisher started. (PID: {os.getpid()})") - start_time = time.time() + start_time = TIME() for _ in range(self.SETTINGS.num_msgs): - current_time = time.time() + current_time = TIME() if current_time - start_time >= self.SETTINGS.max_duration: break yield ( self.OUTPUT, LoadTestSample( - _timestamp=time.time(), + _timestamp=TIME(), counter=self.STATE.counter, dynamic_data=np.zeros( int(self.SETTINGS.dynamic_size // 4), dtype=np.float32 @@ -154,7 +158,7 @@ async def receive(self, sample: LoadTestSample) -> None: if sample.counter != counter + 1: ez.logger.warning(f"{sample.counter - counter - 1} samples skipped!") self.STATE.received_data.append( - (sample._timestamp, time.time(), sample.counter) + (sample._timestamp, TIME(), sample.counter) ) self.STATE.counters[sample.key] = sample.counter @@ -322,10 +326,13 @@ def calculate_metrics(sink: LoadTestSink) -> Metrics: ) rx_timestamps = np.array([rx_ts for _, rx_ts, _ in sink.STATE.received_data]) + rx_timestamps.sort() runtime = max_timestamp - min_timestamp num_samples = len(sink.STATE.received_data) samplerate_mean = num_samples / runtime - samplerate_median = 1.0 / float(np.median(np.diff(rx_timestamps))) + diff_timestamps = np.diff(rx_timestamps) + diff_timestamps = diff_timestamps[np.nonzero(diff_timestamps)] + samplerate_median = 1.0 / float(np.median(diff_timestamps)) latency_mean = total_latency / num_samples latency_median = list(sorted(latency))[len(latency) // 2] total_data = num_samples * sink.SETTINGS.dynamic_size From 9f6f94a2e143b9f23ea5c64c010477d8c4ea8265 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Thu, 20 Nov 2025 15:06:35 -0500 Subject: [PATCH 85/88] update git-blame-ignore-revs with more ruff formatting --- .git-blame-ignore-revs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index 3b6d32f6..3c04e930 100644 --- a/.git-blame-ignore-revs +++ b/.git-blame-ignore-revs @@ -2,3 +2,5 @@ 5b4c220154d33cac15a75d3d7d978a75e67f2b8a # Ran `black` formatter on entire codebase and fixed non-standard line-endings with `dos2unix`. 4e4f20be40a73f2162a565732bae76ea0c812739 +# chore: ruff formatting +5ca5711e7714042f23d1286abf946217d75318c2 From 438b557bc012c0b8cf77348ea181d6b245e460f4 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Mon, 24 Nov 2025 21:27:01 -0500 Subject: [PATCH 86/88] addressed review --- .../lowlevel_api.py | 34 ++--- src/ezmsg/core/__init__.py | 4 + src/ezmsg/core/graphcontext.py | 2 +- src/ezmsg/core/messagecache.py | 25 +++- src/ezmsg/core/messagechannel.py | 12 +- src/ezmsg/core/messagemarshal.py | 2 +- src/ezmsg/core/pubclient.py | 21 ++- src/ezmsg/core/shm.py | 2 +- src/ezmsg/core/subclient.py | 16 ++- tests/ez_test_utils.py | 16 ++- tests/messages/test_key.py | 120 ++++++++-------- tests/test_channel.py | 6 +- tests/test_generator.py | 77 +++++----- tests/test_reconstitute.py | 131 ------------------ tests/test_run.py | 113 ++++++++------- 15 files changed, 248 insertions(+), 333 deletions(-) rename tests/test_new_messaging.py => examples/lowlevel_api.py (70%) delete mode 100644 tests/test_reconstitute.py diff --git a/tests/test_new_messaging.py b/examples/lowlevel_api.py similarity index 70% rename from tests/test_new_messaging.py rename to examples/lowlevel_api.py index 683f1857..9bc5d586 100644 --- a/tests/test_new_messaging.py +++ b/examples/lowlevel_api.py @@ -1,16 +1,13 @@ import asyncio -from ezmsg.core.graphserver import GraphServer, GraphService - -from ezmsg.core.subclient import Subscriber -from ezmsg.core.pubclient import Publisher +import ezmsg.core as ez PORT = 12345 MAX_COUNT = 100 TOPIC = "/TEST" -async def handle_pub(pub: Publisher) -> None: +async def handle_pub(pub: ez.Publisher) -> None: print("Publisher Task Launched") count = 0 @@ -25,13 +22,14 @@ async def handle_pub(pub: Publisher) -> None: print("Publisher Task Concluded") -async def handle_sub(sub: Subscriber) -> None: +async def handle_sub(sub: ez.Subscriber) -> None: print("Subscriber Task Launched") rx_count = 0 while True: async with sub.recv_zero_copy() as msg: - await asyncio.sleep(0.15) + # Uncomment if you want to witness backpressure! + # await asyncio.sleep(0.15) print(msg) rx_count += 1 @@ -43,19 +41,15 @@ async def handle_sub(sub: Subscriber) -> None: async def host(host: str = "127.0.0.1"): # Manually create a GraphServer - server = GraphServer() + server = ez.GraphServer() server.start((host, PORT)) print(f"Created GraphServer @ {server.address}") - # Create a graph_service that will interact with this GraphServer - graph_service = GraphService((host, PORT)) - await graph_service.ensure() - try: - test_pub = await Publisher.create(TOPIC, (host, PORT), host=host) - test_sub1 = await Subscriber.create(TOPIC, (host, PORT)) - test_sub2 = await Subscriber.create(TOPIC, (host, PORT)) + test_pub = await ez.Publisher.create(TOPIC, (host, PORT), host=host) + test_sub1 = await ez.Subscriber.create(TOPIC, (host, PORT)) + test_sub2 = await ez.Subscriber.create(TOPIC, (host, PORT)) await asyncio.sleep(1.0) @@ -85,18 +79,14 @@ async def host(host: str = "127.0.0.1"): async def attach_client(host: str = "127.0.0.1"): - # Attach to a running GraphServer - graph_service = GraphService((host, PORT)) - await graph_service.ensure() - - print(f"Connected to GraphServer @ {graph_service.address}") - sub = await Subscriber.create(TOPIC, (host, PORT)) + sub = await ez.Subscriber.create(TOPIC, (host, PORT)) try: while True: async with sub.recv_zero_copy() as msg: - await asyncio.sleep(1.0) + # Uncomment if you want to see EXTREME backpressure! + # await asyncio.sleep(1.0) print(msg) except asyncio.CancelledError: diff --git a/src/ezmsg/core/__init__.py b/src/ezmsg/core/__init__.py index e3ada5cd..fc361531 100644 --- a/src/ezmsg/core/__init__.py +++ b/src/ezmsg/core/__init__.py @@ -24,6 +24,8 @@ "GraphServer", "GraphContext", "run_command", + "Publisher", + "Subscriber", # All following are deprecated "System", "run_system", @@ -42,6 +44,8 @@ from .graphserver import GraphServer from .graphcontext import GraphContext from .command import run_command +from .pubclient import Publisher +from .subclient import Subscriber # Following imports are deprecated from .backend import run_system diff --git a/src/ezmsg/core/graphcontext.py b/src/ezmsg/core/graphcontext.py index 544f4663..f773569a 100644 --- a/src/ezmsg/core/graphcontext.py +++ b/src/ezmsg/core/graphcontext.py @@ -108,7 +108,7 @@ async def disconnect(self, from_topic: str, to_topic: str) -> None: await GraphService(self.graph_address).disconnect(from_topic, to_topic) self._edges.discard((from_topic, to_topic)) - async def sync(self, timeout: typing.Optional[float] = None) -> None: + async def sync(self, timeout: float | None = None) -> None: """ Synchronize with the graph server. diff --git a/src/ezmsg/core/messagecache.py b/src/ezmsg/core/messagecache.py index c56b0458..6de8544c 100644 --- a/src/ezmsg/core/messagecache.py +++ b/src/ezmsg/core/messagecache.py @@ -3,7 +3,13 @@ from .messagemarshal import MessageMarshal -class CacheMiss(Exception): ... +class CacheMiss(Exception): + """ + Exception raised when a requested message is not found in cache. + + This occurs when trying to retrieve a message that has been evicted + from the cache or was never stored in the first place. + """ class CacheEntry(typing.NamedTuple): @@ -14,9 +20,22 @@ class CacheEntry(typing.NamedTuple): class MessageCache: + """ + Cache for memoryview-backed objects. + + Provides a buffer cache that can store objects in memory + enabling efficient message passing between + processes with automatic eviction based on buffer age. + """ _cache: list[CacheEntry | None] def __init__(self, num_buffers: int) -> None: + """ + Initialize the cache with specified number of buffers. + + :param num_buffers: Number of cache buffers to maintain. + :type num_buffers: int + """ self._cache = [None] * num_buffers def _buf_idx(self, msg_id: int) -> int: @@ -114,10 +133,6 @@ def release(self, msg_id: int) -> None: def clear(self) -> None: """ Release all cached objects - - :param mem: Source memoryview containing serialized object. - :type from_mem: memoryview - :raises UninitializedMemory: If mem buffer is not properly initialized. """ for i in range(len(self._cache)): self._release(i) diff --git a/src/ezmsg/core/messagechannel.py b/src/ezmsg/core/messagechannel.py index 15dd180d..e016f101 100644 --- a/src/ezmsg/core/messagechannel.py +++ b/src/ezmsg/core/messagechannel.py @@ -40,6 +40,8 @@ class Channel: The Channel constructor should not be called directly, instead use Channel.create(...) """ + _SENTINEL = object() + id: UUID pub_id: UUID pid: int @@ -64,7 +66,14 @@ def __init__( num_buffers: int, shm: SHMContext | None, graph_address: AddressType | None, + _guard = None, ) -> None: + if _guard is not self._SENTINEL: + raise TypeError( + "Channel cannot be instantiated directly." + "Use 'await CHANNELS.register(...)' instead." + ) + self.id = id self.pub_id = pub_id self.num_buffers = num_buffers @@ -131,8 +140,9 @@ async def create( raise ValueError(f"failed to create channel {pub_id=}") num_buffers = await read_int(reader) + assert num_buffers > 0, "publisher reports invalid num_buffers" - chan = cls(UUID(id_str), pub_id, num_buffers, shm, graph_address) + chan = cls(UUID(id_str), pub_id, num_buffers, shm, graph_address, _guard=cls._SENTINEL) chan._graph_task = asyncio.create_task( chan._graph_connection(graph_reader, graph_writer), diff --git a/src/ezmsg/core/messagemarshal.py b/src/ezmsg/core/messagemarshal.py index 779fcc81..4621d642 100644 --- a/src/ezmsg/core/messagemarshal.py +++ b/src/ezmsg/core/messagemarshal.py @@ -99,7 +99,7 @@ def msg_id(cls, raw: memoryview | bytes) -> int: :type mem: memoryview | bytes :return: Message ID of encoded message :rtype: int - :raises UndersizedMemory: If buffer is not initialized. + :raises UninitializedMemory: If buffer is not initialized. """ cls._assert_initialized(raw) return bytes_to_uint(raw[_PREAMBLE_LEN : _PREAMBLE_LEN + UINT64_SIZE]) diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index 0094f7d5..e12b71e1 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -56,6 +56,8 @@ class Publisher: and resource management. """ + _SENTINEL = object() + id: UUID pid: int topic: str @@ -134,6 +136,7 @@ async def create( num_buffers=num_buffers, start_paused=start_paused, force_tcp=force_tcp, + _guard=cls._SENTINEL, ) start_port = int( @@ -187,6 +190,7 @@ def __init__( num_buffers: int = 32, start_paused: bool = False, force_tcp: bool = False, + _guard = None ) -> None: """ Initialize a Publisher instance. @@ -205,6 +209,12 @@ def __init__( :param force_tcp: Whether to force TCP transport instead of shared memory. :type force_tcp: bool """ + if _guard is not self._SENTINEL: + raise TypeError( + "Publisher cannot be instantiated directly." + "Use 'await Publisher.create(...)' instead." + ) + self.id = id self.pid = os.getpid() self.topic = topic @@ -301,10 +311,10 @@ async def _channel_connect( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter ) -> None: """ - Handle new subscriber connections. + Handle new channel connections. - Exchanges identification information with connecting subscribers - and sets up subscriber handling tasks. + Exchanges identification information with connecting channels + and sets up channel handling tasks. :param reader: Stream reader for receiving subscriber info. :type reader: asyncio.StreamReader @@ -326,8 +336,11 @@ async def _channel_connect( coro = self._handle_channel(info, reader) self._channel_tasks[channel_id] = asyncio.create_task(coro) writer.write(Command.COMPLETE.value + uint64_to_bytes(self._num_buffers)) + await writer.drain() + + else: + raise ValueError(f"Publisher {self.id}: unexpected command {cmd=}") - await writer.drain() async def _handle_channel( self, info: PubChannelInfo, reader: asyncio.StreamReader diff --git a/src/ezmsg/core/shm.py b/src/ezmsg/core/shm.py index abfdc2d2..279aeb9c 100644 --- a/src/ezmsg/core/shm.py +++ b/src/ezmsg/core/shm.py @@ -120,7 +120,7 @@ async def _graph_connection( ) -> None: try: await reader.read() - logger.debug(f"SHMContext {self.name} GraphServer hangup") + logger.debug(f"SHMContext {self.name} GraphServer disconnected; closing.") except (ConnectionResetError, BrokenPipeError) as e: logger.debug(f"SHMContext {self.name} GraphServer {type(e)}") finally: diff --git a/src/ezmsg/core/subclient.py b/src/ezmsg/core/subclient.py index 981233f4..fee285cc 100644 --- a/src/ezmsg/core/subclient.py +++ b/src/ezmsg/core/subclient.py @@ -31,6 +31,8 @@ class Subscriber: and zero-copy message access patterns with automatic acknowledgment. """ + _SENTINEL = object() + id: UUID topic: str @@ -72,7 +74,7 @@ async def create( sub_id_str = await read_str(reader) sub_id = UUID(sub_id_str) - sub = cls(sub_id, topic, graph_address, **kwargs) + sub = cls(sub_id, topic, graph_address, _guard=cls._SENTINEL, **kwargs) sub._graph_task = asyncio.create_task( sub._graph_connection(reader, writer), @@ -89,7 +91,12 @@ async def create( return sub def __init__( - self, id: UUID, topic: str, graph_address: AddressType | None, **kwargs + self, + id: UUID, + topic: str, + graph_address: AddressType | None, + _guard = None, + **kwargs ) -> None: """ Initialize a Subscriber instance. @@ -104,6 +111,11 @@ def __init__( :type graph_service: GraphService :param kwargs: Additional keyword arguments (unused). """ + if _guard is not self._SENTINEL: + raise TypeError( + "Subscriber cannot be instantiated directly." + "Use 'await Subscriber.create(...)' instead." + ) self.id = id self.topic = topic self._graph_address = graph_address diff --git a/tests/ez_test_utils.py b/tests/ez_test_utils.py index 04537489..f651b2a8 100644 --- a/tests/ez_test_utils.py +++ b/tests/ez_test_utils.py @@ -1,14 +1,18 @@ from dataclasses import asdict, dataclass from collections.abc import AsyncGenerator +from contextlib import contextmanager + import json import os from pathlib import Path import tempfile +import typing import ezmsg.core as ez -def get_test_fn(test_name: str | None = None, extension: str = "txt") -> Path: +@contextmanager +def get_test_fn(test_name: str | None = None, extension: str = "txt") -> typing.Generator[Path, None, None]: """PYTEST compatible temporary test file creator""" # Get current test name if we can.. @@ -23,11 +27,11 @@ def get_test_fn(test_name: str | None = None, extension: str = "txt") -> Path: # full test suite in parallel or when other tests use the same test name. # Use NamedTemporaryFile with delete=False so callers can open/remove it. prefix = f"{test_name}-" if test_name else "test-" - tmp = tempfile.NamedTemporaryFile( - prefix=prefix, suffix=f".{extension}", delete=False - ) - tmp.close() - return Path(tmp.name) + tmp = tempfile.NamedTemporaryFile(prefix=prefix, suffix=f".{extension}") + try: + yield Path(tmp.name) + finally: + tmp.close() # MESSAGE DEFINITIONS diff --git a/tests/messages/test_key.py b/tests/messages/test_key.py index 02bb8da2..aa44ee60 100644 --- a/tests/messages/test_key.py +++ b/tests/messages/test_key.py @@ -74,71 +74,69 @@ async def on_message(self, msg: AxisArray) -> None: def test_set_key_unit(): num_msgs = 20 - test_fn_raw = get_test_fn() - test_fn_keyed = test_fn_raw.parent / ( - test_fn_raw.stem + "_keyed" + test_fn_raw.suffix - ) - - comps = { - "SRC": KeyedAxarrGenerator(num_msgs=num_msgs), - "KEYSETTER": SetKey(key="new key"), - "SINK_RAW": AxarrReceiver(num_msgs=1e9, output_fn=test_fn_raw), - "SINK_KEYED": AxarrReceiver(num_msgs=num_msgs, output_fn=test_fn_keyed), - } - conns = [ - (comps["SRC"].OUTPUT_SIGNAL, comps["SINK_RAW"].INPUT_SIGNAL), - (comps["SRC"].OUTPUT_SIGNAL, comps["KEYSETTER"].INPUT_SIGNAL), - (comps["KEYSETTER"].OUTPUT_SIGNAL, comps["SINK_KEYED"].INPUT_SIGNAL), - ] - ez.run(components=comps, connections=conns) - - with open(test_fn_raw, "r") as file: - raw_results = [json.loads(_) for _ in file.readlines()] - os.remove(test_fn_raw) - assert len(raw_results) == num_msgs - assert all( - [ - _[str(ix + 1)] == ("odd" if ix % 2 else "even") - for ix, _ in enumerate(raw_results) + with get_test_fn() as test_fn_raw: + test_fn_keyed = test_fn_raw.parent / ( + test_fn_raw.stem + "_keyed" + test_fn_raw.suffix + ) + + comps = { + "SRC": KeyedAxarrGenerator(num_msgs=num_msgs), + "KEYSETTER": SetKey(key="new key"), + "SINK_RAW": AxarrReceiver(num_msgs=1e9, output_fn=test_fn_raw), + "SINK_KEYED": AxarrReceiver(num_msgs=num_msgs, output_fn=test_fn_keyed), + } + conns = [ + (comps["SRC"].OUTPUT_SIGNAL, comps["SINK_RAW"].INPUT_SIGNAL), + (comps["SRC"].OUTPUT_SIGNAL, comps["KEYSETTER"].INPUT_SIGNAL), + (comps["KEYSETTER"].OUTPUT_SIGNAL, comps["SINK_KEYED"].INPUT_SIGNAL), ] - ) + ez.run(components=comps, connections=conns) + + with open(test_fn_raw, "r") as file: + raw_results = [json.loads(_) for _ in file.readlines()] + assert len(raw_results) == num_msgs + assert all( + [ + _[str(ix + 1)] == ("odd" if ix % 2 else "even") + for ix, _ in enumerate(raw_results) + ] + ) - with open(test_fn_keyed, "r") as file: - keyed_results = [json.loads(_) for _ in file.readlines()] - os.remove(test_fn_keyed) - assert len(keyed_results) == num_msgs - assert all([_[str(ix + 1)] == "new key" for ix, _ in enumerate(keyed_results)]) + with open(test_fn_keyed, "r") as file: + keyed_results = [json.loads(_) for _ in file.readlines()] + os.remove(test_fn_keyed) + assert len(keyed_results) == num_msgs + assert all([_[str(ix + 1)] == "new key" for ix, _ in enumerate(keyed_results)]) def test_filter_key(): num_msgs = 20 - test_fn_raw = get_test_fn() - test_fn_filtered = test_fn_raw.parent / ( - test_fn_raw.stem + "_keyed" + test_fn_raw.suffix - ) + with get_test_fn() as test_fn_raw: + test_fn_filtered = test_fn_raw.parent / ( + test_fn_raw.stem + "_keyed" + test_fn_raw.suffix + ) + + comps = { + "SRC": KeyedAxarrGenerator(num_msgs=num_msgs), + "FILTER": FilterOnKey(key="odd"), + "SINK_RAW": AxarrReceiver(num_msgs=1e9, output_fn=test_fn_raw), + "SINK_FILTERED": AxarrReceiver( + num_msgs=num_msgs // 2, output_fn=test_fn_filtered + ), + } + conns = [ + (comps["SRC"].OUTPUT_SIGNAL, comps["SINK_RAW"].INPUT_SIGNAL), + (comps["SRC"].OUTPUT_SIGNAL, comps["FILTER"].INPUT_SIGNAL), + (comps["FILTER"].OUTPUT_SIGNAL, comps["SINK_FILTERED"].INPUT_SIGNAL), + ] + ez.run(components=comps, connections=conns) + + with open(test_fn_raw, "r") as file: + raw_results = [json.loads(_) for _ in file.readlines()] + assert len(raw_results) == num_msgs - comps = { - "SRC": KeyedAxarrGenerator(num_msgs=num_msgs), - "FILTER": FilterOnKey(key="odd"), - "SINK_RAW": AxarrReceiver(num_msgs=1e9, output_fn=test_fn_raw), - "SINK_FILTERED": AxarrReceiver( - num_msgs=num_msgs // 2, output_fn=test_fn_filtered - ), - } - conns = [ - (comps["SRC"].OUTPUT_SIGNAL, comps["SINK_RAW"].INPUT_SIGNAL), - (comps["SRC"].OUTPUT_SIGNAL, comps["FILTER"].INPUT_SIGNAL), - (comps["FILTER"].OUTPUT_SIGNAL, comps["SINK_FILTERED"].INPUT_SIGNAL), - ] - ez.run(components=comps, connections=conns) - - with open(test_fn_raw, "r") as file: - raw_results = [json.loads(_) for _ in file.readlines()] - os.remove(test_fn_raw) - assert len(raw_results) == num_msgs - - with open(test_fn_filtered, "r") as file: - filtered_results = [json.loads(_) for _ in file.readlines()] - os.remove(test_fn_filtered) - assert len(filtered_results) == num_msgs // 2 - assert all([_[str(ix + 1)] == "odd" for ix, _ in enumerate(filtered_results)]) + with open(test_fn_filtered, "r") as file: + filtered_results = [json.loads(_) for _ in file.readlines()] + os.remove(test_fn_filtered) + assert len(filtered_results) == num_msgs // 2 + assert all([_[str(ix + 1)] == "odd" for ix, _ in enumerate(filtered_results)]) diff --git a/tests/test_channel.py b/tests/test_channel.py index 9527bdcb..e45f7192 100644 --- a/tests/test_channel.py +++ b/tests/test_channel.py @@ -26,7 +26,7 @@ def _resolved_task(): @pytest.mark.asyncio async def test_channel_acknowledges_remote_messages(): - channel = Channel(uuid4(), uuid4(), 2, None, None) + channel = Channel(uuid4(), uuid4(), 2, None, None, Channel._SENTINEL) channel._pub_writer = DummyWriter() channel._pub_task = _resolved_task() channel._graph_task = _resolved_task() @@ -60,7 +60,7 @@ async def test_channel_acknowledges_remote_messages(): @pytest.mark.asyncio async def test_channel_releases_local_backpressure(monkeypatch): - channel = Channel(uuid4(), uuid4(), 2, None, None) + channel = Channel(uuid4(), uuid4(), 2, None, None, Channel._SENTINEL) channel._pub_writer = DummyWriter() channel._pub_task = _resolved_task() channel._graph_task = _resolved_task() @@ -88,6 +88,6 @@ async def test_channel_releases_local_backpressure(monkeypatch): def test_channel_put_local_requires_local_backpressure(): - channel = Channel(uuid4(), uuid4(), 1, None, None) + channel = Channel(uuid4(), uuid4(), 1, None, None, Channel._SENTINEL) with pytest.raises(ValueError): channel.put_local(1, "no pub") diff --git a/tests/test_generator.py b/tests/test_generator.py index 37921cc9..df013006 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -163,25 +163,25 @@ def test_gen_to_unit_any(): assert MyUnit.OUTPUT.msg_type == list[typing.Any] num_msgs = 5 - test_filename = get_test_fn() - comps = { - "SIMPLE_PUB": MessageGenerator(num_msgs=num_msgs), - "MYUNIT": MyUnit(), - "SIMPLE_SUB": MessageAnyReceiver(num_msgs=num_msgs, output_fn=test_filename), - } - ez.run( - components=comps, - connections=( - (comps["SIMPLE_PUB"].OUTPUT, comps["MYUNIT"].INPUT), - (comps["MYUNIT"].OUTPUT, comps["SIMPLE_SUB"].INPUT), - ), - ) - results = [] - with open(test_filename, "r") as file: - lines = file.readlines() - for line in lines: - results.append(json.loads(line)) - os.remove(test_filename) + with get_test_fn() as test_filename: + comps = { + "SIMPLE_PUB": MessageGenerator(num_msgs=num_msgs), + "MYUNIT": MyUnit(), + "SIMPLE_SUB": MessageAnyReceiver(num_msgs=num_msgs, output_fn=test_filename), + } + ez.run( + components=comps, + connections=( + (comps["SIMPLE_PUB"].OUTPUT, comps["MYUNIT"].INPUT), + (comps["MYUNIT"].OUTPUT, comps["SIMPLE_SUB"].INPUT), + ), + ) + results = [] + with open(test_filename, "r") as file: + lines = file.readlines() + for line in lines: + results.append(json.loads(line)) + # We don't really care about the contents; functionality was confirmed in a separate test. # Keep this simple. assert len(results) == num_msgs @@ -224,25 +224,26 @@ def test_gen_to_unit_axarr(): assert MyUnit.OUTPUT_SIGNAL.msg_type is AxisArray num_msgs = 5 - test_filename = get_test_fn() - comps = { - "SIMPLE_PUB": AxarrGenerator(num_msgs=num_msgs), - "MYUNIT": MyUnit(), - "SIMPLE_SUB": AxarrReceiver(num_msgs=num_msgs, output_fn=test_filename), - } - ez.run( - components=comps, - connections=( - (comps["SIMPLE_PUB"].OUTPUT_SIGNAL, comps["MYUNIT"].INPUT_SIGNAL), - (comps["MYUNIT"].OUTPUT_SIGNAL, comps["SIMPLE_SUB"].INPUT_SIGNAL), - ), - ) - results = [] - with open(test_filename, "r") as file: - lines = file.readlines() - for line in lines: - results.append(json.loads(line)) - os.remove(test_filename) + + with get_test_fn() as test_filename: + comps = { + "SIMPLE_PUB": AxarrGenerator(num_msgs=num_msgs), + "MYUNIT": MyUnit(), + "SIMPLE_SUB": AxarrReceiver(num_msgs=num_msgs, output_fn=test_filename), + } + ez.run( + components=comps, + connections=( + (comps["SIMPLE_PUB"].OUTPUT_SIGNAL, comps["MYUNIT"].INPUT_SIGNAL), + (comps["MYUNIT"].OUTPUT_SIGNAL, comps["SIMPLE_SUB"].INPUT_SIGNAL), + ), + ) + results = [] + with open(test_filename, "r") as file: + lines = file.readlines() + for line in lines: + results.append(json.loads(line)) + assert np.array_equal( results[-1][f"{num_msgs}"], np.hstack([np.arange(_) for _ in range(num_msgs)]) ) diff --git a/tests/test_reconstitute.py b/tests/test_reconstitute.py deleted file mode 100644 index 80e77784..00000000 --- a/tests/test_reconstitute.py +++ /dev/null @@ -1,131 +0,0 @@ -import asyncio -from copy import deepcopy -import timeit - -try: - import matplotlib.pyplot as plt -except ImportError: - plt = None - -from ezmsg.core.graphserver import GraphServer, GraphService -from ezmsg.util.messages.axisarray import AxisArray -from ezmsg.core.messagemarshal import MessageMarshal - -import numpy as np - -ADDR = ("127.0.0.1", 12345) -ITERS = 10000 - - -def setup_graph_and_shm(msg_size): - server = GraphServer() - server.start(ADDR) - graph_service = GraphService(ADDR) - loop = asyncio.get_event_loop() - loop.run_until_complete(graph_service.ensure()) - aa = AxisArray(data=np.random.normal(size=(msg_size,)), dims=["a"]) - with MessageMarshal.serialize(0, aa) as (size, _, _): - msg_size_bytes = size - shm = loop.run_until_complete( - graph_service.create_shm(num_buffers=32, buf_size=msg_size_bytes + 100) - ) - with shm.buffer(0) as mem: - MessageMarshal.to_mem(0, aa, mem) - return server, shm, msg_size_bytes - - -def recon_with_deepcopy_func(shm): - with shm.buffer(0, readonly=True) as mem: - with MessageMarshal.obj_from_mem(mem) as obj: - deepcopy(obj) - - -def recon_without_deepcopy_func(shm): - with shm.buffer(0, readonly=True) as mem: - with MessageMarshal.obj_from_mem(mem) as _obj: - pass - - -def main(): - msg_sizes = [2**i for i in range(8, 25, 2)] # 2**8, 2**10, ..., 2**24 - fanouts = [1, 2, 4, 8, 16, 32] - results = {} - msg_size_bytes_list = [] - - for msg_size in msg_sizes: - print(f"\nTesting MSG_SIZE={msg_size}") - server, shm, msg_size_bytes = setup_graph_and_shm(msg_size) - msg_size_bytes_list.append(msg_size_bytes) - - # Warm up - for _ in range(100): - recon_with_deepcopy_func(shm) - recon_without_deepcopy_func(shm) - - # Time a single reconstitution and a single deepcopy - recon_time = ( - timeit.timeit(stmt=lambda: recon_without_deepcopy_func(shm), number=ITERS) - / ITERS - * 1e9 - ) # ns - deepcopy_time = ( - timeit.timeit(stmt=lambda: recon_with_deepcopy_func(shm), number=ITERS) - / ITERS - * 1e9 - ) # ns - - print(f"Per reconstitution (ns): {recon_time}") - print(f"Per deepcopy (ns): {deepcopy_time}") - - # For each fanout, compute total time for both strategies - total_recon = [] - total_deepcopy = [] - for fanout in fanouts: - # Strategy 1: 1 reconstitution + 1 deepcopy, rest are free - total_deepcopy.append(recon_time + deepcopy_time) - # Strategy 2: reconstitute N times (no deepcopy) - total_recon.append(fanout * recon_time) - - results[msg_size_bytes] = { - "fanouts": fanouts, - "total_recon": total_recon, - "total_deepcopy": total_deepcopy, - } - - server.stop() - - # Plot results - if plt is not None: - plt.figure(figsize=(12, 8)) - for msg_size_bytes in msg_size_bytes_list: - fanouts = results[msg_size_bytes]["fanouts"] - plt.plot( - fanouts, - results[msg_size_bytes]["total_recon"], - label=f"Recon N times ({msg_size_bytes} bytes)", - marker="o", - linestyle="--", - alpha=0.7, - ) - plt.plot( - fanouts, - results[msg_size_bytes]["total_deepcopy"], - label=f"Recon 1 + deepcopy N-1 ({msg_size_bytes} bytes)", - marker="x", - linestyle="-", - alpha=0.7, - ) - plt.xscale("log", base=2) - plt.yscale("log") - plt.xlabel("Fanout (number of consumers)") - plt.ylabel("Total time (ns)") - plt.title("Fanout Tradeoff: Reconstitute N times vs Reconstitute+Deepcopy") - plt.legend(fontsize="small", ncol=2) - plt.grid(True, which="both", ls="--", alpha=0.5) - plt.tight_layout() - plt.savefig("fanout_tradeoff.png") - print("Plot saved as fanout_tradeoff.png") - - -if __name__ == "__main__": - main() diff --git a/tests/test_run.py b/tests/test_run.py index bc5046c2..323bc4ba 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -63,70 +63,69 @@ def func(self): @pytest.mark.parametrize("num_messages", [1, 5, 10]) def test_local_system(toy_system_fixture, num_messages): - test_filename = get_test_fn() - system = toy_system_fixture( - ToySystemSettings(num_msgs=num_messages, output_fn=test_filename) - ) - ez.run(SYSTEM=system, profiler_log_name="test_profiler.log") - assert os.environ.get("EZMSG_PROFILER") == "test_profiler.log" - - results = [] - with open(test_filename, "r") as file: - lines = file.readlines() - for line in lines: - results.append(json.loads(line)) - os.remove(test_filename) - assert len(results) == num_messages + with get_test_fn() as test_filename: + system = toy_system_fixture( + ToySystemSettings(num_msgs=num_messages, output_fn=test_filename) + ) + ez.run(SYSTEM=system, profiler_log_name="test_profiler.log") + assert os.environ.get("EZMSG_PROFILER") == "test_profiler.log" + + results = [] + with open(test_filename, "r") as file: + lines = file.readlines() + for line in lines: + results.append(json.loads(line)) + + assert len(results) == num_messages @pytest.mark.parametrize("passthrough_settings", [False, True]) @pytest.mark.parametrize("num_messages", [1, 5, 10]) def test_run_comps_conns(passthrough_settings, num_messages): - test_filename = get_test_fn() - if passthrough_settings: - comps = { - "SIMPLE_PUB": MessageGenerator(num_msgs=num_messages), - "SIMPLE_SUB": MessageReceiver( - num_msgs=num_messages, output_fn=test_filename - ), - } - else: - comps = { - "SIMPLE_PUB": MessageGenerator( - MessageGeneratorSettings(num_msgs=num_messages) - ), - "SIMPLE_SUB": MessageReceiver( - MessageReceiverSettings(num_msgs=num_messages, output_fn=test_filename) - ), - } - conns = ((comps["SIMPLE_PUB"].OUTPUT, comps["SIMPLE_SUB"].INPUT),) - - ez.run(components=comps, connections=conns) - - results = [] - with open(test_filename, "r") as file: - lines = file.readlines() - for line in lines: - results.append(json.loads(line)) - os.remove(test_filename) - assert len(results) == num_messages + with get_test_fn() as test_filename: + if passthrough_settings: + comps = { + "SIMPLE_PUB": MessageGenerator(num_msgs=num_messages), + "SIMPLE_SUB": MessageReceiver( + num_msgs=num_messages, output_fn=test_filename + ), + } + else: + comps = { + "SIMPLE_PUB": MessageGenerator( + MessageGeneratorSettings(num_msgs=num_messages) + ), + "SIMPLE_SUB": MessageReceiver( + MessageReceiverSettings(num_msgs=num_messages, output_fn=test_filename) + ), + } + conns = ((comps["SIMPLE_PUB"].OUTPUT, comps["SIMPLE_SUB"].INPUT),) + + ez.run(components=comps, connections=conns) + + results = [] + with open(test_filename, "r") as file: + lines = file.readlines() + for line in lines: + results.append(json.loads(line)) + + assert len(results) == num_messages @pytest.mark.parametrize("passthrough_settings", [False, True]) @pytest.mark.parametrize("num_messages", [1, 5, 10]) def test_run_collection(passthrough_settings, num_messages): - test_filename = get_test_fn() - if passthrough_settings: - collection = ToySystem(num_msgs=num_messages, output_fn=test_filename) - else: - collection = ToySystem( - ToySystemSettings(num_msgs=num_messages, output_fn=test_filename) - ) - ez.run(collection) - results = [] - with open(test_filename, "r") as file: - lines = file.readlines() - for line in lines: - results.append(json.loads(line)) - os.remove(test_filename) - assert len(results) == num_messages + with get_test_fn() as test_filename: + if passthrough_settings: + collection = ToySystem(num_msgs=num_messages, output_fn=test_filename) + else: + collection = ToySystem( + ToySystemSettings(num_msgs=num_messages, output_fn=test_filename) + ) + ez.run(collection) + results = [] + with open(test_filename, "r") as file: + lines = file.readlines() + for line in lines: + results.append(json.loads(line)) + assert len(results) == num_messages From a21073e4b36950366faa1c08a7c8a7717d2983dc Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Tue, 25 Nov 2025 11:47:50 -0500 Subject: [PATCH 87/88] updated docstring --- src/ezmsg/core/messagecache.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ezmsg/core/messagecache.py b/src/ezmsg/core/messagecache.py index 6de8544c..f14cd2f7 100644 --- a/src/ezmsg/core/messagecache.py +++ b/src/ezmsg/core/messagecache.py @@ -120,9 +120,9 @@ def release(self, msg_id: int) -> None: """ Release memory for the entry associated with msg_id - :param mem: Source memoryview containing serialized object. - :type from_mem: memoryview - :raises UninitializedMemory: If mem buffer is not properly initialized. + :param msg_id: ID for the message to release. + :type msg_id: int + :raises CacheMiss: If requested msg_id is not in cache. """ buf_idx = self._buf_idx(msg_id) entry = self._cache[buf_idx] From 6132479fdbf68b65a030d403ba9bec6de2f8aa39 Mon Sep 17 00:00:00 2001 From: Griffin Milsap Date: Tue, 25 Nov 2025 11:48:31 -0500 Subject: [PATCH 88/88] Update tests/ez_test_utils.py Co-authored-by: Konrad Pilch --- tests/ez_test_utils.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/ez_test_utils.py b/tests/ez_test_utils.py index f651b2a8..6ac7134f 100644 --- a/tests/ez_test_utils.py +++ b/tests/ez_test_utils.py @@ -28,10 +28,15 @@ def get_test_fn(test_name: str | None = None, extension: str = "txt") -> typing. # Use NamedTemporaryFile with delete=False so callers can open/remove it. prefix = f"{test_name}-" if test_name else "test-" tmp = tempfile.NamedTemporaryFile(prefix=prefix, suffix=f".{extension}") + tmp.close() # Close so others can open it on Windows + path = Path(tmp.name) try: - yield Path(tmp.name) + yield path finally: - tmp.close() + try: + path.unlink() + except FileNotFoundError: + pass # MESSAGE DEFINITIONS