From 6d0e89754b56ea92fd454a5e093c4ae4da096ad1 Mon Sep 17 00:00:00 2001 From: ljwoods2 Date: Mon, 10 Nov 2025 14:30:07 -0700 Subject: [PATCH] server --- imdclient/IMDClient.py | 80 +------ imdclient/IMDProtocol.py | 64 ++++++ imdclient/IMDServer.py | 357 ++++++++++++++++++++++++++++++ imdclient/__init__.py | 2 +- imdclient/tests/test_imdclient.py | 18 +- 5 files changed, 445 insertions(+), 76 deletions(-) create mode 100644 imdclient/IMDServer.py diff --git a/imdclient/IMDClient.py b/imdclient/IMDClient.py index a0f71dd..68a59d5 100644 --- a/imdclient/IMDClient.py +++ b/imdclient/IMDClient.py @@ -216,7 +216,9 @@ def _connect_to_server(self, host, port, socket_bufsize): # /proc/sys/net/core/rmem_default # Max (linux): # /proc/sys/net/core/rmem_max - conn.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, socket_bufsize) + conn.setsockopt( + socket.SOL_SOCKET, socket.SO_RCVBUF, socket_bufsize + ) try: logger.debug(f"IMDClient: Connecting to {host}:{port}") conn.connect((host, port)) @@ -787,7 +789,9 @@ def __init__( raise ValueError("pause_empty_proportion must be between 0 and 1") self._pause_empty_proportion = pause_empty_proportion if unpause_empty_proportion < 0 or unpause_empty_proportion > 1: - raise ValueError("unpause_empty_proportion must be between 0 and 1") + raise ValueError( + "unpause_empty_proportion must be between 0 and 1" + ) self._unpause_empty_proportion = unpause_empty_proportion if buffer_size <= 0: @@ -854,7 +858,9 @@ def wait_for_space(self): logger.debug("IMDProducer: Noticing consumer finished") raise EOFError except Exception as e: - logger.debug(f"IMDProducer: Error waiting for space in buffer: {e}") + logger.debug( + f"IMDProducer: Error waiting for space in buffer: {e}" + ) def pop_empty_imdframe(self): logger.debug("IMDProducer: Getting empty frame") @@ -900,7 +906,9 @@ def pop_full_imdframe(self): imdf = self._full_q.get() else: with self._full_imdf_avail: - while self._full_q.qsize() == 0 and not self._producer_finished: + while ( + self._full_q.qsize() == 0 and not self._producer_finished + ): self._full_imdf_avail.wait() if self._producer_finished and self._full_q.qsize() == 0: @@ -929,67 +937,3 @@ def notify_consumer_finished(self): with self._empty_imdf_avail: # noop if producer isn't waiting self._empty_imdf_avail.notify() - - -class IMDFrame: - def __init__(self, n_atoms, imdsinfo): - if imdsinfo.time: - self.time = 0.0 - self.dt = 0.0 - self.step = 0.0 - else: - self.time = None - self.dt = None - self.step = None - if imdsinfo.energies: - self.energies = { - "step": 0, - "temperature": 0.0, - "total_energy": 0.0, - "potential_energy": 0.0, - "van_der_walls_energy": 0.0, - "coulomb_energy": 0.0, - "bonds_energy": 0.0, - "angles_energy": 0.0, - "dihedrals_energy": 0.0, - "improper_dihedrals_energy": 0.0, - } - else: - self.energies = None - if imdsinfo.box: - self.box = np.empty((3, 3), dtype=np.float32) - else: - self.box = None - if imdsinfo.positions: - self.positions = np.empty((n_atoms, 3), dtype=np.float32) - else: - self.positions = None - if imdsinfo.velocities: - self.velocities = np.empty((n_atoms, 3), dtype=np.float32) - else: - self.velocities = None - if imdsinfo.forces: - self.forces = np.empty((n_atoms, 3), dtype=np.float32) - else: - self.forces = None - - -def imdframe_memsize(n_atoms, imdsinfo) -> int: - """ - Calculate the memory size of an IMDFrame in bytes - """ - memsize = 0 - if imdsinfo.time: - memsize += 8 * 3 - if imdsinfo.energies: - memsize += 4 * 10 - if imdsinfo.box: - memsize += 4 * 9 - if imdsinfo.positions: - memsize += 4 * 3 * n_atoms - if imdsinfo.velocities: - memsize += 4 * 3 * n_atoms - if imdsinfo.forces: - memsize += 4 * 3 * n_atoms - - return memsize diff --git a/imdclient/IMDProtocol.py b/imdclient/IMDProtocol.py index 68caf81..09e0986 100644 --- a/imdclient/IMDProtocol.py +++ b/imdclient/IMDProtocol.py @@ -163,3 +163,67 @@ def parse_header_bytes(data): type = IMDHeaderType(msg_type) # NOTE: add error checking for invalid packet msg_type here return IMDHeader(type, length) + + +class IMDFrame: + def __init__(self, n_atoms, imdsinfo): + if imdsinfo.time: + self.time = 0.0 + self.dt = 0.0 + self.step = 0 + else: + self.time = None + self.dt = None + self.step = None + if imdsinfo.energies: + self.energies = { + "step": 0, + "temperature": 0.0, + "total_energy": 0.0, + "potential_energy": 0.0, + "van_der_walls_energy": 0.0, + "coulomb_energy": 0.0, + "bonds_energy": 0.0, + "angles_energy": 0.0, + "dihedrals_energy": 0.0, + "improper_dihedrals_energy": 0.0, + } + else: + self.energies = None + if imdsinfo.box: + self.box = np.empty((3, 3), dtype=np.float32) + else: + self.box = None + if imdsinfo.positions: + self.positions = np.empty((n_atoms, 3), dtype=np.float32) + else: + self.positions = None + if imdsinfo.velocities: + self.velocities = np.empty((n_atoms, 3), dtype=np.float32) + else: + self.velocities = None + if imdsinfo.forces: + self.forces = np.empty((n_atoms, 3), dtype=np.float32) + else: + self.forces = None + + +def imdframe_memsize(n_atoms, imdsinfo) -> int: + """ + Calculate the memory size of an IMDFrame in bytes + """ + memsize = 0 + if imdsinfo.time: + memsize += 8 * 3 + if imdsinfo.energies: + memsize += 4 * 10 + if imdsinfo.box: + memsize += 4 * 9 + if imdsinfo.positions: + memsize += 4 * 3 * n_atoms + if imdsinfo.velocities: + memsize += 4 * 3 * n_atoms + if imdsinfo.forces: + memsize += 4 * 3 * n_atoms + + return memsize diff --git a/imdclient/IMDServer.py b/imdclient/IMDServer.py new file mode 100644 index 0000000..7f9ea51 --- /dev/null +++ b/imdclient/IMDServer.py @@ -0,0 +1,357 @@ +import socket +from socket import SOL_SOCKET, SO_REUSEADDR +import threading +from .IMDProtocol import * +from .utils import read_into_buf, sock_contains_data, timeit +import logging +import queue +import time +import numpy as np +from typing import Union, Dict +import signal +import atexit +import sys + +logger = logging.getLogger(__name__) + + +class IMDServer: + + def __init__( + self, + version, + n_atoms, + host="0.0.0.0", + port=8888, + time=True, + box=True, + positions=True, + velocities=True, + forces=True, + listen_timeout=60, + read_timeout=5, + **kwargs, + ): + if version not in IMDVERSIONS: + raise ValueError( + f"IMDServer: Incompatible IMD version. Expected version in {IMDVERSIONS}, got {version}" + ) + if version == 2 and (time or box or velocities or forces): + # TODO: and energies? + raise ValueError("IMDServer: IMDv2 only supports positions") + if (version == 2 and not (positions)) or ( + version == 3 + and not (time or box or positions or velocities or forces) + ): + raise ValueError( + f"IMDServer: No IMD flags turned on for IMDv{version}" + ) + + end = ">" if sys.byteorder == "big" else "<" + + self.sinfo = IMDSessionInfo( + version=version, + endianness=end, + wrapped_coords=False, + time=time, + energies=False, + box=box, + positions=positions, + velocities=velocities, + forces=forces, + ) + self.n_atoms = n_atoms + + # shutdown safety + signal.signal(signal.SIGINT, self.signal_handler) + signal.signal(signal.SIGTERM, self.signal_handler) + try: + import IPython + except ImportError: + has_ipython = False + else: + has_ipython = True + + if has_ipython: + try: + from IPython import get_ipython + + if get_ipython() is not None: + kernel = get_ipython().kernel + kernel.pre_handler_hook = lambda: None + kernel.post_handler_hook = lambda: None + logger.debug("Running in Jupyter") + except NameError: + logger.debug("Running in non-jupyter IPython environment") + + atexit.register(self.stop) + + self._alloc_send_buf() + self._header = bytearray(IMDHEADERSIZE) + self._array_dtype = np.dtype(f"{self.sinfo.endianness}f4") + self._time_fmt = f"{self.sinfo.endianness}ddQ" + + try: + self._listen_socket = socket.socket( + socket.AF_INET, socket.SOCK_STREAM + ) + self._listen_socket.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1) + self._listen_socket.bind((host, port)) + + # for the case of port=0 + self.port = self._listen_socket.getsockname()[1] + self._listen_socket.listen() + + if sock_contains_data(self._listen_socket, listen_timeout): + self._conn, _ = self._listen_socket.accept() + else: + # TODO: abstract into loop + raise TimeoutError( + f"IMDServer: No client within {listen_timeout} seconds" + ) + + if version == 2: + self._send_handshakeV2() + elif version == 3: + self._send_handshakeV3() + + self._conn.settimeout(5) + + self._expect_header(IMDHeaderType.IMD_GO) + + self._conn.settimeout(read_timeout) + except Exception as e: + self.stop() + raise e + + def write_frame(self, imdframe): + + if self.sinfo.time: + if imdframe.time is None: + raise ValueError( + f"IMDServer: Time enabled for session, but frame does not contain time information" + ) + + struct.pack_into( + self._time_fmt, + self._time_view, + 0, + imdframe.dt, + imdframe.time, + imdframe.step, + ) + + if self.sinfo.box: + if imdframe.box is None: + raise ValueError( + f"IMDServer: Box enabled for session, but frame does not contain box information" + ) + + src = np.ascontiguousarray(imdframe.box, dtype=self._array_dtype) + target = np.frombuffer(self._box_view, dtype=self._array_dtype) + target = target.reshape(src.shape) + np.copyto(target, src) + + if self.sinfo.positions: + if imdframe.positions is None: + raise ValueError( + f"IMDServer: Positions enabled for session, but frame does not contain positions information" + ) + + src = np.ascontiguousarray( + imdframe.positions, dtype=self._array_dtype + ) + target = np.frombuffer(self._pos_view, dtype=self._array_dtype) + target = target.reshape(src.shape) + np.copyto(target, src) + + if self.sinfo.velocities: + if imdframe.velocities is None: + raise ValueError( + f"IMDServer: Velocities enabled for session, but frame does not contain velocities information" + ) + + src = np.ascontiguousarray( + imdframe.velocities, dtype=self._array_dtype + ) + target = np.frombuffer(self._vel_view, dtype=self._array_dtype) + target = target.reshape(src.shape) + np.copyto(target, src) + + if self.sinfo.forces: + if imdframe.forces is None: + raise ValueError( + f"IMDServer: Forces enabled for session, but frame does not contain forces information" + ) + + src = np.ascontiguousarray( + imdframe.forces, dtype=self._array_dtype + ) + target = np.frombuffer(self._force_view, dtype=self._array_dtype) + target = target.reshape(src.shape) + np.copyto(target, src) + + self._conn.sendall(self._send_buf) + + def _alloc_send_buf(self): + + buf_len = 0 + xvf_size = 12 * self.n_atoms + if self.sinfo.time: + buf_len += IMDHEADERSIZE + IMDTIMEPACKETLENGTH + if self.sinfo.box: + buf_len += IMDHEADERSIZE + IMDBOXPACKETLENGTH + if self.sinfo.positions: + buf_len += IMDHEADERSIZE + xvf_size + if self.sinfo.velocities: + buf_len += IMDHEADERSIZE + xvf_size + if self.sinfo.forces: + buf_len += IMDHEADERSIZE + xvf_size + + self._send_buf = bytearray(buf_len) + offset_bytes = 0 + + if self.sinfo.time: + self._send_buf[offset_bytes : offset_bytes + IMDHEADERSIZE] = ( + create_header_bytes(IMDHeaderType.IMD_TIME, 1) + ) + offset_bytes += IMDHEADERSIZE + self._time_view = memoryview(self._send_buf)[ + offset_bytes : offset_bytes + IMDTIMEPACKETLENGTH + ] + offset_bytes += IMDTIMEPACKETLENGTH + + if self.sinfo.box: + self._send_buf[offset_bytes : offset_bytes + IMDHEADERSIZE] = ( + create_header_bytes(IMDHeaderType.IMD_BOX, 1) + ) + offset_bytes += IMDHEADERSIZE + self._box_view = memoryview(self._send_buf)[ + offset_bytes : offset_bytes + IMDBOXPACKETLENGTH + ] + offset_bytes += IMDBOXPACKETLENGTH + + if self.sinfo.positions: + self._send_buf[offset_bytes : offset_bytes + IMDHEADERSIZE] = ( + create_header_bytes(IMDHeaderType.IMD_FCOORDS, self.n_atoms) + ) + offset_bytes += IMDHEADERSIZE + self._pos_view = memoryview(self._send_buf)[ + offset_bytes : offset_bytes + xvf_size + ] + offset_bytes += xvf_size + + if self.sinfo.velocities: + self._send_buf[offset_bytes : offset_bytes + IMDHEADERSIZE] = ( + create_header_bytes(IMDHeaderType.IMD_VELOCITIES, self.n_atoms) + ) + offset_bytes += IMDHEADERSIZE + self._vel_view = memoryview(self._send_buf)[ + offset_bytes : offset_bytes + xvf_size + ] + offset_bytes += xvf_size + + if self.sinfo.forces: + self._send_buf[offset_bytes : offset_bytes + IMDHEADERSIZE] = ( + create_header_bytes(IMDHeaderType.IMD_FORCES, self.n_atoms) + ) + offset_bytes += IMDHEADERSIZE + self._force_view = memoryview(self._send_buf)[ + offset_bytes : offset_bytes + xvf_size + ] + offset_bytes += xvf_size + + def _send_handshakeV2(self): + header = struct.pack("!i", IMDHeaderType.IMD_HANDSHAKE.value) + header += struct.pack(f"{self.sinfo.endianness}i", 2) + self._conn.sendall(header) + + def _send_handshakeV3(self): + logger.debug(f"InThreadIMDServer: Sending handshake V3") + packet = struct.pack("!i", IMDHeaderType.IMD_HANDSHAKE.value) + packet += struct.pack(f"{self.sinfo.endianness}i", 3) + self._conn.sendall(packet) + + sinfo = struct.pack("!ii", IMDHeaderType.IMD_SESSIONINFO.value, 7) + time = 1 if self.sinfo.time else 0 + energies = 1 if self.sinfo.energies else 0 + box = 1 if self.sinfo.box else 0 + positions = 1 if self.sinfo.positions else 0 + velocities = 1 if self.sinfo.velocities else 0 + forces = 1 if self.sinfo.forces else 0 + wrapped_coords = 0 + sinfo += struct.pack( + f"{self.sinfo.endianness}BBBBBBB", + time, + energies, + box, + positions, + wrapped_coords, + velocities, + forces, + ) + logger.debug(f"IMDServer: Sending session info") + self._conn.sendall(sinfo) + + def _expect_header(self, expected_type, expected_value=None): + header = self._get_header() + + if header.type != expected_type: + raise RuntimeError( + f"IMDProducer: Expected header type {expected_type}, got {header.type}" + ) + # Sometimes we do not care what the value is + if expected_value is not None and header.length != expected_value: + if expected_type in [ + IMDHeaderType.IMD_FCOORDS, + IMDHeaderType.IMD_VELOCITIES, + IMDHeaderType.IMD_FORCES, + ]: + raise RuntimeError( + f"IMDProducer: Expected n_atoms value {expected_value}, got {header.length}. " + + "Ensure you are using the correct topology file." + ) + else: + raise RuntimeError( + f"IMDProducer: Expected header value {expected_value}, got {header.length}" + ) + + def _get_header(self): + self._read(self._header) + return IMDHeader(self._header) + + def _read(self, buf): + """Wraps `read_into_buf` call to give uniform error handling which indicates end of stream""" + try: + read_into_buf(self._conn, buf) + except (ConnectionError, TimeoutError, BlockingIOError, Exception): + # ConnectionError: Server is definitely done sending frames, socket is closed + # TimeoutError: Server is *likely* done sending frames. + # BlockingIOError: Occurs when timeout is 0 in place of a TimeoutError. Server is *likely* done sending frames + # OSError: Occurs when main thread disconnects from the server and closes the socket, but producer thread attempts to read another frame + # Exception: Something unexpected happened + raise EOFError + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop() + return False + + def stop(self): + print("Stopping") + try: + self._conn.close() + except: + pass + try: + self._listen_socket.close() + except: + pass + + def signal_handler(self, *args, **kwargs): + """Catch SIGINT to allow clean shutdown on CTRL+C.""" + logger.debug("Intercepted signal") + self.stop() + logger.debug("Shutdown success") diff --git a/imdclient/__init__.py b/imdclient/__init__.py index cbd38be..317291f 100644 --- a/imdclient/__init__.py +++ b/imdclient/__init__.py @@ -2,8 +2,8 @@ IMDClient """ -# Don't import IMDReader here, eventually it may be moved to a separate package from .IMDClient import IMDClient +from .IMDServer import IMDServer from importlib.metadata import version __version__ = version("imdclient") diff --git a/imdclient/tests/test_imdclient.py b/imdclient/tests/test_imdclient.py index 656d08c..83f5c40 100644 --- a/imdclient/tests/test_imdclient.py +++ b/imdclient/tests/test_imdclient.py @@ -12,8 +12,8 @@ COORDINATES_H5MD, ) -from imdclient.IMDClient import imdframe_memsize, IMDClient -from imdclient.IMDProtocol import IMDHeaderType +from imdclient.IMDClient import IMDClient +from imdclient.IMDProtocol import IMDHeaderType, imdframe_memsize from .utils import ( create_default_imdsinfo_v3, ) @@ -178,16 +178,20 @@ def test_continue_after_disconnect(self, universe, imdsinfo, cont): IMDHeaderType.IMD_WAIT, expected_length=(int)(not cont) ) - def test_incorrect_atom_count(self, server_client_incorrect_atoms, universe): + def test_incorrect_atom_count( + self, server_client_incorrect_atoms, universe + ): server, client = server_client_incorrect_atoms - + server.send_frame(0) - + with pytest.raises(EOFError) as exc_info: client.get_imdframe() - + error_msg = str(exc_info.value) - assert f"Expected n_atoms value {universe.atoms.n_atoms + 1}" in error_msg + assert ( + f"Expected n_atoms value {universe.atoms.n_atoms + 1}" in error_msg + ) assert f"got {universe.atoms.n_atoms}" in error_msg assert "Ensure you are using the correct topology file" in error_msg