diff --git a/src/openfe_analysis/reader.py b/src/openfe_analysis/reader.py index b9ac9cd..83cca1c 100644 --- a/src/openfe_analysis/reader.py +++ b/src/openfe_analysis/reader.py @@ -1,4 +1,5 @@ -from typing import Optional +import pathlib +from typing import Literal, Optional import netCDF4 as nc import numpy as np @@ -52,16 +53,20 @@ def _determine_iteration_dt(dataset) -> float: class FEReader(ReaderBase): - """A MDAnalysis Reader for NetCDF files created by + """ + MDAnalysis Reader for NetCDF files created by `openmmtools.multistate.MultiStateReporter` - Looks along a multistate NetCDF file along one of two axes: - - constant state/lambda (varying replica) - - constant replica (varying lambda) + Provides a 1D trajectory along either: + + - constant Hamiltonian state (`index_method="state"`) + - constant replica (`index_method="replica"`) + + selected via the `index` argument. """ - _state_id: Optional[int] - _replica_id: Optional[int] + _multistate_index: Optional[int] + _index_method: Optional[str] _frame_index: int _dataset: nc.Dataset _dataset_owner: bool @@ -70,35 +75,27 @@ class FEReader(ReaderBase): units = {"time": "ps", "length": "nanometer"} - def __init__(self, filename, convert_units=True, state_id=None, replica_id=None, **kwargs): + def __init__( + self, + filename: str | pathlib.Path | nc.Dataset, + *, + index: int, + index_method: Literal["state", "replica"] = "state", + convert_units: bool = True, + **kwargs, + ): """ Parameters ---------- filename : pathlike or nc.Dataset - path to the .nc file + Path to the .nc file or an open Dataset. + index : int + Index of the state or replica to extract. May be negative. + index_method : {"state", "replica"}, default "state" + Whether `index` refers to a Hamiltonian state or a replica. convert_units : bool - convert positions to Angstrom - state_id : Optional[int] - The Hamiltonian state index to extract. Must be defined if - ``replica_id`` is not defined. May be negative (see notes below). - replica_id : Optional[int] - The replica index to extract. Must be defined if ``state_id`` - is not defined. May be negative (see notes below). - - Notes - ----- - A negative index may be passed to either ``state_id`` or - ``replica_id``. This will be interpreted as indexing in reverse - starting from the last state/replica. For example, passing a - value of -2 for ``replica_id`` will select the before last replica. + Convert positions to Angstrom. """ - if not ((state_id is None) ^ (replica_id is None)): - raise ValueError( - "Specify one and only one of state or replica, " - f"got state id={state_id} " - f"replica_id={replica_id}" - ) - super().__init__(filename, convert_units, **kwargs) if isinstance(filename, nc.Dataset): @@ -108,15 +105,18 @@ def __init__(self, filename, convert_units=True, state_id=None, replica_id=None, self._dataset = nc.Dataset(filename) self._dataset_owner = True - # Handle the negative ID case - if state_id is not None and state_id < 0: - state_id = range(self._dataset.dimensions["state"].size)[state_id] + if index_method not in {"state", "replica"}: + raise ValueError(f"index_method must be 'state' or 'replica', got {index_method}") + + self._index_method = index_method - if replica_id is not None and replica_id < 0: - replica_id = range(self._dataset.dimensions["replica"].size)[replica_id] + # Handle the negative ID case + if index_method == "state": + size = self._dataset.dimensions["state"].size + else: + size = self._dataset.dimensions["replica"].size - self._state_id = state_id - self._replica_id = replica_id + self._multistate_index = index % size self._n_atoms = self._dataset.dimensions["atom"].size self.ts = Timestep(self._n_atoms) @@ -131,6 +131,10 @@ def _format_hint(thing) -> bool: # can pass raw nc datasets through to reduce open/close operations return isinstance(thing, nc.Dataset) + @property + def multistate_index(self) -> int: + return self._multistate_index + @property def n_atoms(self) -> int: return self._n_atoms @@ -139,6 +143,10 @@ def n_atoms(self) -> int: def n_frames(self) -> int: return len(self._frames) + @property + def index_method(self) -> str: + return self._index_method + @staticmethod def parse_n_atoms(filename, **kwargs) -> int: with nc.Dataset(filename) as ds: @@ -153,17 +161,19 @@ def _read_next_timestep(self, ts=None) -> Timestep: def _read_frame(self, frame: int) -> Timestep: self._frame_index = frame - if self._state_id is not None: + frame = self._frames[self._frame_index] + + if self._index_method == "state": rep = multistate._state_to_replica( - self._dataset, self._state_id, self._frames[self._frame_index] + self._dataset, + self._multistate_index, + frame, ) else: - rep = self._replica_id + rep = self._multistate_index - pos = multistate._replica_positions_at_frame( - self._dataset, rep, self._frames[self._frame_index] - ) - dim = multistate._get_unitcell(self._dataset, rep, self._frames[self._frame_index]) + pos = multistate._replica_positions_at_frame(self._dataset, rep, frame) + dim = multistate._get_unitcell(self._dataset, rep, frame) if pos is None: errmsg = ( diff --git a/src/openfe_analysis/rmsd.py b/src/openfe_analysis/rmsd.py index 9684bb2..2f5774a 100644 --- a/src/openfe_analysis/rmsd.py +++ b/src/openfe_analysis/rmsd.py @@ -36,7 +36,8 @@ def make_Universe(top: pathlib.Path, trj: nc.Dataset, state: int) -> mda.Univers u = mda.Universe( top, trj, - state_id=state, + index=state, + index_method="state", format=FEReader, ) prot = u.select_atoms("protein and name CA") diff --git a/src/openfe_analysis/tests/test_reader.py b/src/openfe_analysis/tests/test_reader.py index c8d4bb2..da5c6b2 100644 --- a/src/openfe_analysis/tests/test_reader.py +++ b/src/openfe_analysis/tests/test_reader.py @@ -44,7 +44,7 @@ def test_determine_position_indices_warns_for_old_nc(tmp_path): def test_universe_creation(simulation_nc, hybrid_system_pdb): - u = mda.Universe(hybrid_system_pdb, simulation_nc, format=FEReader, state_id=0) + u = mda.Universe(hybrid_system_pdb, simulation_nc, format=FEReader, index=0) # Check that a Universe exists assert u @@ -92,7 +92,7 @@ def test_universe_creation(simulation_nc, hybrid_system_pdb): def test_universe_from_nc_file(simulation_skipped_nc, hybrid_system_skipped_pdb): with nc.Dataset(simulation_skipped_nc) as ds: - u = mda.Universe(hybrid_system_skipped_pdb, ds, format="MultiStateReporter", state_id=0) + u = mda.Universe(hybrid_system_skipped_pdb, ds, format="MultiStateReporter", index=0) assert u assert len(u.atoms) == 9178 @@ -105,7 +105,7 @@ def test_universe_creation_noconversion(simulation_skipped_nc, hybrid_system_ski hybrid_system_skipped_pdb, simulation_skipped_nc, format=FEReader, - state_id=0, + index=0, convert_units=False, ) assert u.trajectory.ts.frame == 0 @@ -124,20 +124,23 @@ def test_universe_creation_noconversion(simulation_skipped_nc, hybrid_system_ski def test_fereader_negative_state(simulation_skipped_nc, hybrid_system_skipped_pdb): - u = mda.Universe(hybrid_system_skipped_pdb, simulation_skipped_nc, format=FEReader, state_id=-1) + u = mda.Universe(hybrid_system_skipped_pdb, simulation_skipped_nc, format=FEReader, index=-1) - assert u.trajectory._state_id == 10 - assert u.trajectory._replica_id is None + assert u.trajectory._multistate_index == 10 u.trajectory.close() def test_fereader_negative_replica(simulation_skipped_nc, hybrid_system_skipped_pdb): u = mda.Universe( - hybrid_system_skipped_pdb, simulation_skipped_nc, format=FEReader, replica_id=-2 + hybrid_system_skipped_pdb, + simulation_skipped_nc, + format=FEReader, + index=-2, + index_method="replica", ) - assert u.trajectory._state_id is None - assert u.trajectory._replica_id == 9 + assert u.trajectory._multistate_index == 9 + assert u.trajectory._index_method == "replica" u.trajectory.close() @@ -145,13 +148,13 @@ def test_fereader_negative_replica(simulation_skipped_nc, hybrid_system_skipped_ def test_fereader_replica_state_id_error( simulation_skipped_nc, hybrid_system_skipped_pdb, rep_id, state_id ): - with pytest.raises(ValueError, match="Specify one and only one"): + with pytest.raises(ValueError, match="index_method must be 'state'"): _ = mda.Universe( hybrid_system_skipped_pdb, simulation_skipped_nc, format=FEReader, - state_id=state_id, - replica_id=rep_id, + index=0, + index_method="wrong", ) @@ -162,7 +165,8 @@ def test_simulation_skipped_nc(simulation_skipped_nc, hybrid_system_skipped_pdb) hybrid_system_skipped_pdb, simulation_skipped_nc, format=FEReader, - replica_id=0, + index=0, + index_method="replica", ) # Wrap all atoms inside the simulation box diff --git a/src/openfe_analysis/tests/test_rmsd.py b/src/openfe_analysis/tests/test_rmsd.py index b9022ce..a9e8b09 100644 --- a/src/openfe_analysis/tests/test_rmsd.py +++ b/src/openfe_analysis/tests/test_rmsd.py @@ -111,7 +111,7 @@ def test_multichain_rmsd_shifting(simulation_skipped_nc, hybrid_system_skipped_p u = mda.Universe( hybrid_system_skipped_pdb, simulation_skipped_nc, - state_id=0, + index=0, format=FEReader, ) prot = u.select_atoms("protein") diff --git a/src/openfe_analysis/tests/test_transformations.py b/src/openfe_analysis/tests/test_transformations.py index b24959d..120700b 100644 --- a/src/openfe_analysis/tests/test_transformations.py +++ b/src/openfe_analysis/tests/test_transformations.py @@ -17,7 +17,7 @@ def universe(hybrid_system_skipped_pdb, simulation_skipped_nc): hybrid_system_skipped_pdb, simulation_skipped_nc, format="MultiStateReporter", - state_id=0, + index=0, ) yield u u.trajectory.close() @@ -54,7 +54,7 @@ def test_nojump(hybrid_system_pdb, simulation_nc): hybrid_system_pdb, simulation_nc, format="MultiStateReporter", - state_id=2, + index=2, ) # find frame where protein would teleport across boundary and check it prot = universe.select_atoms("protein and name CA") diff --git a/src/openfe_analysis/tests/utils/test_multistate.py b/src/openfe_analysis/tests/utils/test_multistate.py index 4b4cdf3..cb2d8db 100644 --- a/src/openfe_analysis/tests/utils/test_multistate.py +++ b/src/openfe_analysis/tests/utils/test_multistate.py @@ -7,9 +7,11 @@ from openfe_analysis import __version__ from openfe_analysis.utils.multistate import ( _create_new_dataset, + _determine_position_indices, _get_unitcell, _replica_positions_at_frame, _state_to_replica, + trajectory_from_multistate, ) @@ -41,6 +43,17 @@ def test_replica_positions_at_frame(dataset): ) +def test_determine_position_indices_inconsistent(monkeypatch, dataset): + # Force np.diff to return inconsistent spacing + def fake_diff(x): + return np.array([1, 2, 1]) + + monkeypatch.setattr(np, "diff", fake_diff) + + with pytest.raises(ValueError, match="consistent frame rate"): + _determine_position_indices(dataset) + + def test_create_new_dataset(tmp_path): file_path = tmp_path / "foo.nc" with _create_new_dataset(file_path, 100, title="bar") as ds: @@ -91,3 +104,80 @@ def test_simulation_skipped_nc_no_positions_box_vectors_frame1( ): assert _get_unitcell(skipped_dataset, 1, 1) is None assert skipped_dataset.variables["positions"][1][0].mask.all() + + +def test_trajectory_invalid_index_method(tmp_path): + dummy_input = tmp_path / "dummy.nc" + dummy_output = tmp_path / "out.nc" + + # Create minimal NetCDF + ds = nc.Dataset(dummy_input, "w", format="NETCDF3_64BIT_OFFSET") + ds.createDimension("atom", 1) + ds.createDimension("frame", 1) + pos = ds.createVariable("positions", "f4", ("frame", "atom")) + pos[:] = 0.0 + ds.close() + + with pytest.raises(ValueError, match="index_method must be 'state' or 'replica'"): + trajectory_from_multistate(dummy_input, dummy_output, index=0, index_method="foo") + + +def test_trajectory_frame_without_positions(tmp_path): + dummy_input = tmp_path / "dummy.nc" + dummy_output = tmp_path / "out.nc" + + # Minimal NetCDF file + with nc.Dataset(dummy_input, "w", format="NETCDF4") as ds: + ds.createDimension("frame", 2) # at least 2 frames + ds.createDimension("replica", 1) + ds.createDimension("atom", 1) + ds.createDimension("spatial", 3) + ds.createDimension("iteration", 2) # at least 2 iterations + + positions = ds.createVariable("positions", "f4", ("frame", "replica", "atom", "spatial")) + positions.units = "nanometer" + positions[:] = np.ma.masked # All positions masked + + # Expect RuntimeError due to missing positions + with pytest.raises(RuntimeError, match="Frame without positions encountered"): + trajectory_from_multistate(dummy_input, dummy_output, index=0, index_method="replica") + + +def test_trajectory_success(tmp_path): + dummy_input = tmp_path / "dummy.nc" + dummy_output = tmp_path / "out.nc" + + # Minimal valid NetCDF with positions, box vectors, and iteration dimension + ds = nc.Dataset(dummy_input, "w", format="NETCDF3_64BIT_OFFSET") + ds.createDimension("atom", 2) + ds.createDimension("frame", 2) + ds.createDimension("replica", 2) + ds.createDimension("state", 2) + ds.createDimension("spatial", 3) + ds.createDimension("iteration", 2) # Added for _determine_position_indices + + # positions: frame x replica x atom x spatial + pos = ds.createVariable("positions", "f4", ("frame", "replica", "atom", "spatial")) + pos.units = "nanometer" + pos[:] = np.zeros((2, 2, 2, 3), dtype=np.float32) + + # box_vectors: frame x replica x 3 x 3 + bv = ds.createVariable("box_vectors", "f8", ("frame", "replica", "spatial", "spatial")) + bv.units = "nanometer" + bv[:] = np.tile(np.eye(3), (2, 2, 1, 1)) + + # states: frame x replica + st = ds.createVariable("states", "i4", ("frame", "replica")) + st[:] = np.array([[0, 1], [0, 1]], dtype=np.int32) # replica 0->state 0, replica1->state1 + + ds.close() + + # Call function for replica extraction + trajectory_from_multistate(dummy_input, dummy_output, index=1, index_method="replica") + + # Check output file exists and contains positions + out_ds = nc.Dataset(dummy_output, "r") + assert out_ds.variables["coordinates"].shape == (2, 2, 3) + assert out_ds.variables["cell_lengths"].shape == (2, 3) + assert out_ds.variables["cell_angles"].shape == (2, 3) + out_ds.close() diff --git a/src/openfe_analysis/utils/multistate.py b/src/openfe_analysis/utils/multistate.py index f885e2b..3816889 100644 --- a/src/openfe_analysis/utils/multistate.py +++ b/src/openfe_analysis/utils/multistate.py @@ -1,6 +1,6 @@ import warnings from pathlib import Path -from typing import Optional, Tuple +from typing import Literal, Optional, Tuple import netCDF4 as nc import numpy as np @@ -213,14 +213,12 @@ def _get_unitcell( def trajectory_from_multistate( input_file: Path, output_file: Path, - state_number: Optional[int] = None, - replica_number: Optional[int] = None, + index: int, + index_method: Literal["state", "replica"] = "state", ) -> None: """ - Extract a state's trajectory (in an AMBER compliant format) - from a MultiState sampler generated NetCDF file. - - Either a state or replica index must be supplied, but not both! + Extract a 1D trajectory (in an AMBER compliant format) from a MultiState + sampler generated NetCDF file. Parameters ---------- @@ -228,54 +226,52 @@ def trajectory_from_multistate( Path to the input MultiState sampler generated NetCDF file. output_file : path.Pathlib Path to the AMBER-style NetCDF trajectory to be written. - state_number : int, optional - Index of the state to write out to the trajectory. - replica_number : int, optional - Index of the replica to write out + index : int + Index of the state or replica to extract. May be negative. + index_method : {"state", "replica"}, default "state" + Whether `index` refers to a Hamiltonian state or a replica. """ - if not ((state_number is None) ^ (replica_number is None)): - raise ValueError( - "Supply either state or replica number, " - f"got state_number={state_number} " - f"and replica_number={replica_number}" - ) + if index_method not in {"state", "replica"}: + raise ValueError(f"index_method must be 'state' or 'replica', got {index_method}") # Open MultiState NC file and get number of atoms and frames multistate = nc.Dataset(input_file, "r") n_atoms = len(multistate.variables["positions"][0][0]) - n_replicas = len(multistate.variables["positions"][0]) frame_list = _determine_position_indices(multistate) n_frames = len(frame_list) - # Sanity check - if state_number is not None and (state_number + 1 > n_replicas): - # Note this works for now, but when we have more states - # than replicas (e.g. SAMS) this won't really work - errmsg = "State does not exist" - raise ValueError(errmsg) + # Normalize index (handles negatives) + if index_method == "state": + size = multistate.dimensions["state"].size + else: + size = multistate.dimensions["replica"].size + + index = index % size # Create output AMBER NetCDF convention file traj = _create_new_dataset( - output_file, n_atoms, title=f"state {state_number} trajectory from {input_file}" + output_file, + n_atoms, + title=f"{index_method} {index} trajectory from {input_file}", ) - replica_id: int = -1 - if replica_number is not None: - replica_id = replica_number + replica_id: int = index if index_method == "replica" else -1 # Loopy de loop over n_frames so that the new Dataset # is just 0 -> n_frames for frame in range(n_frames): - if state_number is not None: - replica_id = _state_to_replica(multistate, state_number, frame_list[frame]) + if index_method == "state": + replica_id = _state_to_replica(multistate, index, frame_list[frame]) + + pos = _replica_positions_at_frame(multistate, replica_id, frame_list[frame]) + if pos is None: + raise RuntimeError("Frame without positions encountered") + + traj.variables["coordinates"][frame] = pos.to("angstrom").m - traj.variables["coordinates"][frame] = ( - _replica_positions_at_frame(multistate, replica_id, frame_list[frame]).to("angstrom").m - ) unitcell = _get_unitcell(multistate, replica_id, frame_list[frame]) traj.variables["cell_lengths"][frame] = unitcell[:3] traj.variables["cell_angles"][frame] = unitcell[3:] - # Make sure to clean up when you are done multistate.close() traj.close()