Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
74 commits
Select commit Hold shift + click to select a range
4b257b2
First stab at fixing multi-chain RMSD analysis
hannahbaumann Dec 18, 2025
236ff72
Some updates
hannahbaumann Dec 18, 2025
d274b0c
Add tests
hannahbaumann Dec 18, 2025
f3634dd
Some fixes
hannahbaumann Dec 19, 2025
e92adb3
Add another test
hannahbaumann Dec 19, 2025
b528ca2
Move some tests to use skipped smaller data
hannahbaumann Jan 16, 2026
a477bc1
Test out zenodo dealings
hannahbaumann Jan 16, 2026
ad84082
Try to improve speed
hannahbaumann Jan 16, 2026
8ba8087
Try removing locking
hannahbaumann Jan 16, 2026
ead7951
Run downloads before the testing to have a single download for all th…
hannahbaumann Jan 19, 2026
f898a35
add import pooch
hannahbaumann Jan 19, 2026
c675a5c
Test out more
hannahbaumann Jan 19, 2026
88e456d
Ensure datasets get closed
hannahbaumann Jan 19, 2026
73a8e4d
Move to per test download again
hannahbaumann Jan 19, 2026
43aaca2
Remove commented out lines
hannahbaumann Jan 21, 2026
c165525
Test out adding an extra slash
hannahbaumann Jan 21, 2026
5f17770
Switch to all version doi
hannahbaumann Jan 21, 2026
c28286e
Download url directly
hannahbaumann Jan 21, 2026
197b6ba
Small fix
hannahbaumann Jan 21, 2026
b45390a
Change url
hannahbaumann Jan 21, 2026
1d70936
Add missing s
hannahbaumann Jan 21, 2026
20084c3
Switch to api url
hannahbaumann Jan 21, 2026
1a1c916
Revert to old cli
hannahbaumann Jan 22, 2026
59c7392
Update cli.py
hannahbaumann Jan 22, 2026
5e135ab
Update cli.py
hannahbaumann Jan 22, 2026
a9a8780
Update tests for new results
hannahbaumann Jan 23, 2026
92af45b
Change shift to enable other boxes
hannahbaumann Jan 23, 2026
8ea3585
Update multichain code
hannahbaumann Jan 26, 2026
220d504
Add ligand in shifting
hannahbaumann Jan 26, 2026
8c44cb2
Use new shift class instead of old minimiser since that one is no lon…
hannahbaumann Jan 26, 2026
d13495f
Update some tests
hannahbaumann Jan 26, 2026
c34c97c
Update conftest
hannahbaumann Jan 26, 2026
0161673
Update to v2
hannahbaumann Jan 26, 2026
9b6ca69
Update tests
hannahbaumann Jan 26, 2026
1d5c849
Update rmsd test, currently large rmsd till rmsd fix comes in
hannahbaumann Jan 26, 2026
f4e88e2
Make last test pass
hannahbaumann Jan 26, 2026
bd0c8ee
Switch to zenodo fetch
hannahbaumann Jan 26, 2026
ba4c912
remove lines
hannahbaumann Jan 26, 2026
157c02f
Update tests with large errors multichain failure
hannahbaumann Jan 27, 2026
98ea023
Apply suggestion from @hannahbaumann
hannahbaumann Jan 28, 2026
c5b2d70
Reuse zenodo specification
hannahbaumann Jan 28, 2026
54576ab
reorder install
hannahbaumann Jan 28, 2026
3aa52a5
Small fix
hannahbaumann Jan 28, 2026
ff6991a
Remove flaky retries
hannahbaumann Jan 28, 2026
7a30f69
Small fix
hannahbaumann Jan 28, 2026
ac1fe7b
Merge in the fix flakiness PR and update tests
hannahbaumann Jan 28, 2026
67a0913
Add wrapping to get positions to be greater than 0
hannahbaumann Jan 29, 2026
e706b11
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 29, 2026
7277f20
Replace state_id and replica_id with index and view
hannahbaumann Jan 30, 2026
6c1a6c5
Apply suggestion from @hannahbaumann
hannahbaumann Feb 2, 2026
f4637f4
Remove unnecessary make_whole
hannahbaumann Feb 2, 2026
ae19d9e
Address review comments
hannahbaumann Feb 2, 2026
b215370
Small fix
hannahbaumann Feb 2, 2026
e7d6935
Merge branch 'main' into fix_rmsd_multichain
hannahbaumann Feb 3, 2026
deb5126
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 3, 2026
fa3227e
Merge branch 'main' into fix_rmsd_multichain
hannahbaumann Feb 6, 2026
40fd09c
Merge branch 'main' into reader_refactor
hannahbaumann Feb 6, 2026
9986820
also update the handling of indices of state vs replica in the trajec…
hannahbaumann Feb 6, 2026
3a531c2
Merge branch 'main' into reader_refactor
hannahbaumann Feb 6, 2026
e153d51
small fixes
hannahbaumann Feb 6, 2026
43eb039
Small fix
hannahbaumann Feb 6, 2026
73fe2ee
Merge branch 'main' into fix_rmsd_multichain
hannahbaumann Feb 6, 2026
7be4c53
Update tests
hannahbaumann Feb 10, 2026
68a2aab
Apply suggestion from @hannahbaumann
hannahbaumann Feb 10, 2026
198156b
Update src/openfe_analysis/transformations.py
hannahbaumann Feb 16, 2026
6f85466
Update src/openfe_analysis/transformations.py
hannahbaumann Feb 16, 2026
1d19473
Update src/openfe_analysis/transformations.py
hannahbaumann Feb 16, 2026
6cb52af
Update src/openfe_analysis/rmsd.py
hannahbaumann Feb 16, 2026
2ccc7a8
Update src/openfe_analysis/rmsd.py
hannahbaumann Feb 16, 2026
24163e5
Modify test for closest image shift
hannahbaumann Feb 16, 2026
5994774
Small fix
hannahbaumann Feb 16, 2026
db0f27d
Merge branch 'fix_rmsd_multichain' into reader_refactor
hannahbaumann Feb 16, 2026
cd83ce8
Merge branch 'main' into reader_refactor
hannahbaumann Feb 16, 2026
6b9578a
Update src/openfe_analysis/rmsd.py
hannahbaumann Feb 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 54 additions & 44 deletions src/openfe_analysis/reader.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks great!

Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Optional
import pathlib
from typing import Literal, Optional

import netCDF4 as nc
import numpy as np
Expand Down Expand Up @@ -52,16 +53,20 @@ def _determine_iteration_dt(dataset) -> float:


class FEReader(ReaderBase):
"""A MDAnalysis Reader for NetCDF files created by
"""
MDAnalysis Reader for NetCDF files created by
`openmmtools.multistate.MultiStateReporter`

Looks along a multistate NetCDF file along one of two axes:
- constant state/lambda (varying replica)
- constant replica (varying lambda)
Provides a 1D trajectory along either:

- constant Hamiltonian state (`index_method="state"`)
- constant replica (`index_method="replica"`)

selected via the `index` argument.
"""

_state_id: Optional[int]
_replica_id: Optional[int]
_multistate_index: Optional[int]
_index_method: Optional[str]
_frame_index: int
_dataset: nc.Dataset
_dataset_owner: bool
Expand All @@ -70,35 +75,27 @@ class FEReader(ReaderBase):

units = {"time": "ps", "length": "nanometer"}

def __init__(self, filename, convert_units=True, state_id=None, replica_id=None, **kwargs):
def __init__(
self,
filename: str | pathlib.Path | nc.Dataset,
*,
index: int,
index_method: Literal["state", "replica"] = "state",
convert_units: bool = True,
**kwargs,
):
"""
Parameters
----------
filename : pathlike or nc.Dataset
path to the .nc file
Path to the .nc file or an open Dataset.
index : int
Index of the state or replica to extract. May be negative.
index_method : {"state", "replica"}, default "state"
Whether `index` refers to a Hamiltonian state or a replica.
convert_units : bool
convert positions to Angstrom
state_id : Optional[int]
The Hamiltonian state index to extract. Must be defined if
``replica_id`` is not defined. May be negative (see notes below).
replica_id : Optional[int]
The replica index to extract. Must be defined if ``state_id``
is not defined. May be negative (see notes below).

Notes
-----
A negative index may be passed to either ``state_id`` or
``replica_id``. This will be interpreted as indexing in reverse
starting from the last state/replica. For example, passing a
value of -2 for ``replica_id`` will select the before last replica.
Convert positions to Angstrom.
"""
if not ((state_id is None) ^ (replica_id is None)):
raise ValueError(
"Specify one and only one of state or replica, "
f"got state id={state_id} "
f"replica_id={replica_id}"
)

super().__init__(filename, convert_units, **kwargs)

if isinstance(filename, nc.Dataset):
Expand All @@ -108,15 +105,18 @@ def __init__(self, filename, convert_units=True, state_id=None, replica_id=None,
self._dataset = nc.Dataset(filename)
self._dataset_owner = True

# Handle the negative ID case
if state_id is not None and state_id < 0:
state_id = range(self._dataset.dimensions["state"].size)[state_id]
if index_method not in {"state", "replica"}:
raise ValueError(f"index_method must be 'state' or 'replica', got {index_method}")

self._index_method = index_method

if replica_id is not None and replica_id < 0:
replica_id = range(self._dataset.dimensions["replica"].size)[replica_id]
# Handle the negative ID case
if index_method == "state":
size = self._dataset.dimensions["state"].size
else:
size = self._dataset.dimensions["replica"].size

self._state_id = state_id
self._replica_id = replica_id
self._multistate_index = index % size

self._n_atoms = self._dataset.dimensions["atom"].size
self.ts = Timestep(self._n_atoms)
Expand All @@ -131,6 +131,10 @@ def _format_hint(thing) -> bool:
# can pass raw nc datasets through to reduce open/close operations
return isinstance(thing, nc.Dataset)

@property
def multistate_index(self) -> int:
    """Index of the selected state or replica.

    This is the value stored by ``__init__`` after wrapping the
    user-supplied ``index`` with ``index % size``, so a negative
    ``index`` is reported as its non-negative equivalent.
    """
    return self._multistate_index

@property
def n_atoms(self) -> int:
    """Number of atoms, as read from the dataset's ``atom`` dimension."""
    return self._n_atoms
Expand All @@ -139,6 +143,10 @@ def n_atoms(self) -> int:
def n_frames(self) -> int:
return len(self._frames)

@property
def index_method(self) -> str:
    """Axis that ``multistate_index`` refers to: ``"state"`` or ``"replica"``."""
    return self._index_method

@staticmethod
def parse_n_atoms(filename, **kwargs) -> int:
with nc.Dataset(filename) as ds:
Expand All @@ -153,17 +161,19 @@ def _read_next_timestep(self, ts=None) -> Timestep:
def _read_frame(self, frame: int) -> Timestep:
self._frame_index = frame

if self._state_id is not None:
frame = self._frames[self._frame_index]

if self._index_method == "state":
rep = multistate._state_to_replica(
self._dataset, self._state_id, self._frames[self._frame_index]
self._dataset,
self._multistate_index,
frame,
)
else:
rep = self._replica_id
rep = self._multistate_index

pos = multistate._replica_positions_at_frame(
self._dataset, rep, self._frames[self._frame_index]
)
dim = multistate._get_unitcell(self._dataset, rep, self._frames[self._frame_index])
pos = multistate._replica_positions_at_frame(self._dataset, rep, frame)
dim = multistate._get_unitcell(self._dataset, rep, frame)

if pos is None:
errmsg = (
Expand Down
3 changes: 2 additions & 1 deletion src/openfe_analysis/rmsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def make_Universe(top: pathlib.Path, trj: nc.Dataset, state: int) -> mda.Univers
u = mda.Universe(
top,
trj,
state_id=state,
index=state,
index_method="state",
format=FEReader,
)
prot = u.select_atoms("protein and name CA")
Expand Down
30 changes: 17 additions & 13 deletions src/openfe_analysis/tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_determine_position_indices_warns_for_old_nc(tmp_path):


def test_universe_creation(simulation_nc, hybrid_system_pdb):
u = mda.Universe(hybrid_system_pdb, simulation_nc, format=FEReader, state_id=0)
u = mda.Universe(hybrid_system_pdb, simulation_nc, format=FEReader, index=0)

# Check that a Universe exists
assert u
Expand Down Expand Up @@ -92,7 +92,7 @@ def test_universe_creation(simulation_nc, hybrid_system_pdb):

def test_universe_from_nc_file(simulation_skipped_nc, hybrid_system_skipped_pdb):
with nc.Dataset(simulation_skipped_nc) as ds:
u = mda.Universe(hybrid_system_skipped_pdb, ds, format="MultiStateReporter", state_id=0)
u = mda.Universe(hybrid_system_skipped_pdb, ds, format="MultiStateReporter", index=0)

assert u
assert len(u.atoms) == 9178
Expand All @@ -105,7 +105,7 @@ def test_universe_creation_noconversion(simulation_skipped_nc, hybrid_system_ski
hybrid_system_skipped_pdb,
simulation_skipped_nc,
format=FEReader,
state_id=0,
index=0,
convert_units=False,
)
assert u.trajectory.ts.frame == 0
Expand All @@ -124,34 +124,37 @@ def test_universe_creation_noconversion(simulation_skipped_nc, hybrid_system_ski


def test_fereader_negative_state(simulation_skipped_nc, hybrid_system_skipped_pdb):
u = mda.Universe(hybrid_system_skipped_pdb, simulation_skipped_nc, format=FEReader, state_id=-1)
u = mda.Universe(hybrid_system_skipped_pdb, simulation_skipped_nc, format=FEReader, index=-1)

assert u.trajectory._state_id == 10
assert u.trajectory._replica_id is None
assert u.trajectory._multistate_index == 10
u.trajectory.close()


def test_fereader_negative_replica(simulation_skipped_nc, hybrid_system_skipped_pdb):
u = mda.Universe(
hybrid_system_skipped_pdb, simulation_skipped_nc, format=FEReader, replica_id=-2
hybrid_system_skipped_pdb,
simulation_skipped_nc,
format=FEReader,
index=-2,
index_method="replica",
)

assert u.trajectory._state_id is None
assert u.trajectory._replica_id == 9
assert u.trajectory._multistate_index == 9
assert u.trajectory._index_method == "replica"
u.trajectory.close()


@pytest.mark.parametrize("rep_id, state_id", [[None, None], [1, 1]])
def test_fereader_replica_state_id_error(
simulation_skipped_nc, hybrid_system_skipped_pdb, rep_id, state_id
):
with pytest.raises(ValueError, match="Specify one and only one"):
with pytest.raises(ValueError, match="index_method must be 'state'"):
_ = mda.Universe(
hybrid_system_skipped_pdb,
simulation_skipped_nc,
format=FEReader,
state_id=state_id,
replica_id=rep_id,
index=0,
index_method="wrong",
)


Expand All @@ -162,7 +165,8 @@ def test_simulation_skipped_nc(simulation_skipped_nc, hybrid_system_skipped_pdb)
hybrid_system_skipped_pdb,
simulation_skipped_nc,
format=FEReader,
replica_id=0,
index=0,
index_method="replica",
)

# Wrap all atoms inside the simulation box
Expand Down
2 changes: 1 addition & 1 deletion src/openfe_analysis/tests/test_rmsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_multichain_rmsd_shifting(simulation_skipped_nc, hybrid_system_skipped_p
u = mda.Universe(
hybrid_system_skipped_pdb,
simulation_skipped_nc,
state_id=0,
index=0,
format=FEReader,
)
prot = u.select_atoms("protein")
Expand Down
4 changes: 2 additions & 2 deletions src/openfe_analysis/tests/test_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def universe(hybrid_system_skipped_pdb, simulation_skipped_nc):
hybrid_system_skipped_pdb,
simulation_skipped_nc,
format="MultiStateReporter",
state_id=0,
index=0,
)
yield u
u.trajectory.close()
Expand Down Expand Up @@ -54,7 +54,7 @@ def test_nojump(hybrid_system_pdb, simulation_nc):
hybrid_system_pdb,
simulation_nc,
format="MultiStateReporter",
state_id=2,
index=2,
)
# find frame where protein would teleport across boundary and check it
prot = universe.select_atoms("protein and name CA")
Expand Down
90 changes: 90 additions & 0 deletions src/openfe_analysis/tests/utils/test_multistate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
from openfe_analysis import __version__
from openfe_analysis.utils.multistate import (
_create_new_dataset,
_determine_position_indices,
_get_unitcell,
_replica_positions_at_frame,
_state_to_replica,
trajectory_from_multistate,
)


Expand Down Expand Up @@ -41,6 +43,17 @@ def test_replica_positions_at_frame(dataset):
)


def test_determine_position_indices_inconsistent(monkeypatch, dataset):
    """Uneven frame spacing must be reported as a ``ValueError``."""
    # Stub out np.diff so the spacing always looks irregular
    uneven_spacing = [1, 2, 1]
    monkeypatch.setattr(np, "diff", lambda x: np.array(uneven_spacing))

    with pytest.raises(ValueError, match="consistent frame rate"):
        _determine_position_indices(dataset)


def test_create_new_dataset(tmp_path):
file_path = tmp_path / "foo.nc"
with _create_new_dataset(file_path, 100, title="bar") as ds:
Expand Down Expand Up @@ -91,3 +104,80 @@ def test_simulation_skipped_nc_no_positions_box_vectors_frame1(
):
assert _get_unitcell(skipped_dataset, 1, 1) is None
assert skipped_dataset.variables["positions"][1][0].mask.all()


def test_trajectory_invalid_index_method(tmp_path):
    """An unknown ``index_method`` must be rejected with a ``ValueError``."""
    dummy_input = tmp_path / "dummy.nc"
    dummy_output = tmp_path / "out.nc"

    # Create minimal NetCDF. The context manager closes the file even if
    # the setup raises, so no open handle is left behind in tmp_path.
    with nc.Dataset(dummy_input, "w", format="NETCDF3_64BIT_OFFSET") as ds:
        ds.createDimension("atom", 1)
        ds.createDimension("frame", 1)
        pos = ds.createVariable("positions", "f4", ("frame", "atom"))
        pos[:] = 0.0

    with pytest.raises(ValueError, match="index_method must be 'state' or 'replica'"):
        trajectory_from_multistate(dummy_input, dummy_output, index=0, index_method="foo")


def test_trajectory_frame_without_positions(tmp_path):
    """Extraction must fail with ``RuntimeError`` when every position is masked."""
    source_nc = tmp_path / "dummy.nc"
    target_nc = tmp_path / "out.nc"

    # Smallest layout the test needs: at least two frames and two iterations
    dimension_sizes = {
        "frame": 2,
        "replica": 1,
        "atom": 1,
        "spatial": 3,
        "iteration": 2,
    }

    with nc.Dataset(source_nc, "w", format="NETCDF4") as ds:
        for dim_name, dim_size in dimension_sizes.items():
            ds.createDimension(dim_name, dim_size)

        masked_positions = ds.createVariable(
            "positions", "f4", ("frame", "replica", "atom", "spatial")
        )
        masked_positions.units = "nanometer"
        masked_positions[:] = np.ma.masked  # no frame carries any coordinates

    # The reader is expected to refuse a frame that has no positions
    with pytest.raises(RuntimeError, match="Frame without positions encountered"):
        trajectory_from_multistate(source_nc, target_nc, index=0, index_method="replica")


def test_trajectory_success(tmp_path):
    """Extracting one replica from a tiny synthetic file writes a trajectory.

    The input holds 2 frames, 2 replicas and 2 atoms; the output is checked
    for the shapes of its ``coordinates``, ``cell_lengths`` and
    ``cell_angles`` variables.
    """
    dummy_input = tmp_path / "dummy.nc"
    dummy_output = tmp_path / "out.nc"

    # Minimal valid NetCDF with positions, box vectors, and iteration dimension.
    # Written inside a context manager so the file is closed even if the
    # setup raises part-way through.
    with nc.Dataset(dummy_input, "w", format="NETCDF3_64BIT_OFFSET") as ds:
        ds.createDimension("atom", 2)
        ds.createDimension("frame", 2)
        ds.createDimension("replica", 2)
        ds.createDimension("state", 2)
        ds.createDimension("spatial", 3)
        ds.createDimension("iteration", 2)  # Added for _determine_position_indices

        # positions: frame x replica x atom x spatial
        pos = ds.createVariable("positions", "f4", ("frame", "replica", "atom", "spatial"))
        pos.units = "nanometer"
        pos[:] = np.zeros((2, 2, 2, 3), dtype=np.float32)

        # box_vectors: frame x replica x 3 x 3
        bv = ds.createVariable("box_vectors", "f8", ("frame", "replica", "spatial", "spatial"))
        bv.units = "nanometer"
        bv[:] = np.tile(np.eye(3), (2, 2, 1, 1))

        # states: frame x replica
        st = ds.createVariable("states", "i4", ("frame", "replica"))
        # replica 0 -> state 0, replica 1 -> state 1 in both frames
        st[:] = np.array([[0, 1], [0, 1]], dtype=np.int32)

    # Call function for replica extraction
    trajectory_from_multistate(dummy_input, dummy_output, index=1, index_method="replica")

    # Check output file exists and contains positions. The context manager
    # guarantees the handle is released even when an assertion fails.
    with nc.Dataset(dummy_output, "r") as out_ds:
        assert out_ds.variables["coordinates"].shape == (2, 2, 3)
        assert out_ds.variables["cell_lengths"].shape == (2, 3)
        assert out_ds.variables["cell_angles"].shape == (2, 3)
Loading