Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 74 additions & 5 deletions descent/targets/thermo.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,15 +404,61 @@ def default_config(
raise NotImplementedError(phase)


def select_config(
phase: Phase,
temperature: float,
pressure: float | None,
custom_config: dict[str, SimulationConfig] | None = None,
) -> SimulationConfig:
"""
A helper method to choose the simulation config based on the phase
with the desired temperature and pressure.
If a custom configuration is not available the default will be used.

Args:
phase: The phase of the simulation.
temperature: The temperature [K] at which to run the simulation.
pressure: The pressure [atm] at which to run the simulation
custom_config: The custom simulation configuration for each phase.

Returns:
The simulation configuration for the given phase.
"""
if custom_config is None:
custom_config = {}

try:
config = custom_config[phase]
# edit the config with the desired temperature and pressure
temperature = temperature * openmm.unit.kelvin
pressure = pressure * openmm.unit.atmosphere
for stage in config.equilibrate:
if isinstance(stage, smee.mm.SimulationConfig):
stage.temperature = temperature
stage.pressure = pressure

config.production.temperature = temperature
config.production.pressure = pressure

except KeyError:
config = default_config(phase=phase, temperature=temperature, pressure=pressure)

return config


def _plan_simulations(
entries: list[DataEntry], topologies: dict[str, smee.TensorTopology]
entries: list[DataEntry],
topologies: dict[str, smee.TensorTopology],
simulation_config: dict[str, SimulationConfig] | None = None,
) -> tuple[dict[Phase, _SystemDict], list[dict[str, SimulationKey]]]:
"""Plan the simulations required to compute the properties in a dataset.

Args:
entries: The entries in the dataset.
topologies: The topologies of the molecules present in the dataset, with keys
of mapped SMILES patterns.
simulation_config: The (optional) simulation configuration, should contain
a config for each phase if not provided the default will be used.

Returns:
The systems to simulate and the simulations required to compute each property.
Expand All @@ -428,7 +474,9 @@ def _plan_simulations(

required_sims: dict[str, SimulationKey] = {}

bulk_config = default_config("bulk", entry["temperature"], entry["pressure"])
bulk_config = select_config(
"bulk", entry["temperature"], entry["pressure"], simulation_config
)
max_mols = bulk_config.max_mols

if _REQUIRES_BULK_SIM[data_type]:
Expand Down Expand Up @@ -506,6 +554,7 @@ def _compute_observables(
force_field: smee.TensorForceField,
output_dir: pathlib.Path,
cached_dir: pathlib.Path | None,
simulation_config: dict[str, SimulationConfig] | None = None,
) -> _Observables:
traj_hash = hashlib.sha256(pickle.dumps(key)).hexdigest()
traj_name = f"{phase}-{traj_hash}-frames.msgpack"
Expand All @@ -529,7 +578,12 @@ def _compute_observables(

output_path = output_dir / traj_name

config = default_config(phase, key.temperature, key.pressure)
config = select_config(
phase=phase,
temperature=key.temperature,
pressure=key.pressure,
custom_config=simulation_config,
)
_simulate(system, force_field, config, output_path)

return _Observables(
Expand Down Expand Up @@ -646,6 +700,7 @@ def predict(
cached_dir: pathlib.Path | None = None,
per_type_scales: dict[DataType, float] | None = None,
verbose: bool = False,
simulation_config: dict[str, SimulationConfig] | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Predict the properties in a dataset using molecular simulation, or by reweighting
previous simulation data.
Expand All @@ -661,15 +716,25 @@ def predict(
per_type_scales: The scale factor to apply to each data type. A default of 1.0
will be used for any data type not specified.
verbose: Whether to log additional information.
simulation_config: The (optional) simulation configuration, should contain
a config for each phase if not provided the default will be used.
"""

entries: list[DataEntry] = [*descent.utils.dataset.iter_dataset(dataset)]

required_simulations, entry_to_simulation = _plan_simulations(entries, topologies)
required_simulations, entry_to_simulation = _plan_simulations(
entries, topologies, simulation_config
)
observables = {
phase: {
key: _compute_observables(
phase, key, system, force_field, output_dir, cached_dir
phase,
key,
system,
force_field,
output_dir,
cached_dir,
simulation_config,
)
for key, system in systems.items()
}
Expand Down Expand Up @@ -736,6 +801,7 @@ def default_closure(
dataset: datasets.Dataset,
per_type_scales: dict[DataType, float] | None = None,
verbose: bool = False,
simulation_config: dict[str, SimulationConfig] | None = None,
) -> descent.optim.ClosureFn:
"""Return a default closure function for training against thermodynamic
properties.
Expand All @@ -747,6 +813,8 @@ def default_closure(
dataset: The dataset to train against.
per_type_scales: The scale factor to apply to each data type.
verbose: Whether to log additional information about predictions.
simulation_config: The (optional) simulation configuration, should contain
a config for each phase if not provided the default will be used.

Returns:
The default closure function.
Expand All @@ -767,6 +835,7 @@ def closure_fn(
None,
per_type_scales,
verbose,
simulation_config,
)
loss, gradient, hessian = ((y_pred - y_ref) ** 2).sum(), None, None

Expand Down
104 changes: 92 additions & 12 deletions descent/tests/targets/test_thermo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import descent.utils.dataset
from descent.targets.thermo import (
DataEntry,
SimulationConfig,
SimulationKey,
_compute_observables,
_convert_entry_to_system,
Expand All @@ -21,6 +22,7 @@
default_config,
extract_smiles,
predict,
select_config,
)


Expand Down Expand Up @@ -178,17 +180,87 @@ def test_default_config(phase, pressure, expected_n_mols):
)


def test_select_config():
custom_config = {
"bulk": SimulationConfig(
max_mols=1000,
gen_coords=smee.mm.GenerateCoordsConfig(),
equilibrate=[
smee.mm.MinimizationConfig(),
# short NVT equilibration simulation
smee.mm.SimulationConfig(
temperature=300 * openmm.unit.kelvin,
pressure=None,
n_steps=50000,
timestep=2.0 * openmm.unit.femtosecond,
),
smee.mm.SimulationConfig(
temperature=300 * openmm.unit.kelvin,
pressure=1 * openmm.unit.atmosphere,
n_steps=100000,
timestep=2.0 * openmm.unit.femtosecond,
),
],
production=smee.mm.SimulationConfig(
temperature=300 * openmm.unit.kelvin,
pressure=1 * openmm.unit.atmosphere,
n_steps=1000000,
timestep=2.0 * openmm.unit.femtosecond,
),
production_frequency=2000,
)
}
temperature = 298.15 * openmm.unit.kelvin
pressure = 1 * openmm.unit.atmosphere
config = select_config(
phase="bulk",
temperature=temperature.value_in_unit(openmm.unit.kelvin),
pressure=pressure.value_in_unit(openmm.unit.atmosphere),
custom_config=custom_config,
)
# make sure the custom config has been changed to match what was requested
for stage in config.equilibrate:
if isinstance(stage, smee.mm.SimulationConfig):
assert stage.temperature == temperature
assert stage.pressure == pressure

assert config.production.temperature == temperature
assert config.production.pressure == pressure
assert config.max_mols == 1000


@pytest.mark.parametrize(
"max_mols", [pytest.param(256, id="256"), pytest.param(1000, id="1000")]
)
def test_plan_simulations(
mock_density_pure, mock_density_binary, mock_hvap, mock_hmix, mocker
mock_density_pure, mock_density_binary, mock_hvap, mock_hmix, mocker, max_mols
):
topology_co = mocker.Mock()
topology_cco = mocker.Mock()
topology_cccc = mocker.Mock()

topologies = {"CO": topology_co, "CCO": topology_cco, "CCCC": topology_cccc}

# some mock config
custom_config = {
"bulk": SimulationConfig(
max_mols=max_mols,
gen_coords=smee.mm.GenerateCoordsConfig(),
equilibrate=[smee.mm.MinimizationConfig()],
production=smee.mm.SimulationConfig(
temperature=300 * openmm.unit.kelvin,
pressure=None,
n_steps=5000,
timestep=1.0 * openmm.unit.femtosecond,
),
production_frequency=1000,
)
}

required_simulations, entry_to_simulation = _plan_simulations(
[mock_density_pure, mock_density_binary, mock_hvap, mock_hmix], topologies
[mock_density_pure, mock_density_binary, mock_hvap, mock_hmix],
topologies,
custom_config,
)

assert sorted(required_simulations) == ["bulk", "vacuum"]
Expand All @@ -202,25 +274,25 @@ def test_plan_simulations(

expected_cccc_key = SimulationKey(
("CCCC",),
(256,),
(max_mols,),
mock_hvap["temperature"],
mock_hvap["pressure"],
)
expected_co_key = SimulationKey(
("CO",),
(256,),
(max_mols,),
mock_density_pure["temperature"],
mock_density_pure["pressure"],
)
expected_cco_key = SimulationKey(
("CCO",),
(256,),
(max_mols,),
mock_density_binary["temperature"],
mock_density_binary["pressure"],
)
expected_cco_co_key = SimulationKey(
("CCO", "CO"),
(128, 128),
(max_mols / 2, max_mols / 2),
mock_density_binary["temperature"],
mock_density_binary["pressure"],
)
Expand All @@ -234,16 +306,19 @@ def test_plan_simulations(

assert sorted(required_simulations["bulk"]) == sorted(expected_bulk_keys)

assert required_simulations["bulk"][expected_cccc_key].n_copies == [256]
assert required_simulations["bulk"][expected_cccc_key].n_copies == [max_mols]
assert required_simulations["bulk"][expected_cccc_key].topologies == [topology_cccc]

assert required_simulations["bulk"][expected_cco_key].n_copies == [256]
assert required_simulations["bulk"][expected_cco_key].n_copies == [max_mols]
assert required_simulations["bulk"][expected_cco_key].topologies == [topology_cco]

assert required_simulations["bulk"][expected_co_key].n_copies == [256]
assert required_simulations["bulk"][expected_co_key].n_copies == [max_mols]
assert required_simulations["bulk"][expected_co_key].topologies == [topology_co]

assert required_simulations["bulk"][expected_cco_co_key].n_copies == [128, 128]
assert required_simulations["bulk"][expected_cco_co_key].n_copies == [
max_mols / 2,
max_mols / 2,
]
assert required_simulations["bulk"][expected_cco_co_key].topologies == [
topology_cco,
topology_co,
Expand Down Expand Up @@ -334,7 +409,9 @@ def test_compute_observables_reweighted(tmp_cwd, mocker):
expected_path = cached_dir / f"{phase}-{expected_hash}-frames.msgpack"
expected_path.touch()

result = _compute_observables(phase, key, mock_system, mock_ff, tmp_cwd, cached_dir)
result = _compute_observables(
phase, key, mock_system, mock_ff, tmp_cwd, cached_dir, None
)
assert result.mean == mock_result
assert {*result.std} == {*result.mean}

Expand Down Expand Up @@ -375,7 +452,9 @@ def test_compute_observables_simulated(tmp_cwd, mocker):
expected_path = tmp_cwd / f"{phase}-{expected_hash}-frames.msgpack"
expected_path.touch()

result = _compute_observables(phase, key, mock_system, mock_ff, tmp_cwd, cached_dir)
result = _compute_observables(
phase, key, mock_system, mock_ff, tmp_cwd, cached_dir, None
)
assert result == mock_result

mock_simulate.assert_called_once_with(
Expand Down Expand Up @@ -540,6 +619,7 @@ def test_predict(tmp_cwd, mock_density_pure, mocker):
mock_ff,
tmp_cwd,
None,
None,
)

expected_y_ref = torch.tensor([mock_density_pure["value"] * mock_scale])
Expand Down
Loading