diff --git a/descent/targets/thermo.py b/descent/targets/thermo.py index 0457146..8451efb 100644 --- a/descent/targets/thermo.py +++ b/descent/targets/thermo.py @@ -404,8 +404,52 @@ def default_config( raise NotImplementedError(phase) +def select_config( + phase: Phase, + temperature: float, + pressure: float | None, + custom_config: dict[str, SimulationConfig] | None = None, +) -> SimulationConfig: + """ + A helper method to choose the simulation config based on the phase + with the desired temperature and pressure. + If a custom configuration is not available the default will be used. + + Args: + phase: The phase of the simulation. + temperature: The temperature [K] at which to run the simulation. + pressure: The pressure [atm] at which to run the simulation + custom_config: The custom simulation configuration for each phase. + + Returns: + The simulation configuration for the given phase. + """ + if custom_config is None: + custom_config = {} + + try: + config = custom_config[phase] + # edit the config with the desired temperature and pressure + temperature = temperature * openmm.unit.kelvin + pressure = pressure * openmm.unit.atmosphere + for stage in config.equilibrate: + if isinstance(stage, smee.mm.SimulationConfig): + stage.temperature = temperature + stage.pressure = pressure + + config.production.temperature = temperature + config.production.pressure = pressure + + except KeyError: + config = default_config(phase=phase, temperature=temperature, pressure=pressure) + + return config + + def _plan_simulations( - entries: list[DataEntry], topologies: dict[str, smee.TensorTopology] + entries: list[DataEntry], + topologies: dict[str, smee.TensorTopology], + simulation_config: dict[str, SimulationConfig] | None = None, ) -> tuple[dict[Phase, _SystemDict], list[dict[str, SimulationKey]]]: """Plan the simulations required to compute the properties in a dataset. @@ -413,6 +457,8 @@ def _plan_simulations( entries: The entries in the dataset. topologies: The topologies of the molecules present in the dataset, with keys of mapped SMILES patterns. + simulation_config: The (optional) simulation configuration, should contain + a config for each phase if not provided the default will be used. Returns: The systems to simulate and the simulations required to compute each property. @@ -428,7 +474,9 @@ def _plan_simulations( required_sims: dict[str, SimulationKey] = {} - bulk_config = default_config("bulk", entry["temperature"], entry["pressure"]) + bulk_config = select_config( + "bulk", entry["temperature"], entry["pressure"], simulation_config + ) max_mols = bulk_config.max_mols if _REQUIRES_BULK_SIM[data_type]: @@ -506,6 +554,7 @@ def _compute_observables( force_field: smee.TensorForceField, output_dir: pathlib.Path, cached_dir: pathlib.Path | None, + simulation_config: dict[str, SimulationConfig] | None = None, ) -> _Observables: traj_hash = hashlib.sha256(pickle.dumps(key)).hexdigest() traj_name = f"{phase}-{traj_hash}-frames.msgpack" @@ -529,7 +578,12 @@ def _compute_observables( output_path = output_dir / traj_name - config = default_config(phase, key.temperature, key.pressure) + config = select_config( + phase=phase, + temperature=key.temperature, + pressure=key.pressure, + custom_config=simulation_config, + ) _simulate(system, force_field, config, output_path) return _Observables( @@ -646,6 +700,7 @@ def predict( cached_dir: pathlib.Path | None = None, per_type_scales: dict[DataType, float] | None = None, verbose: bool = False, + simulation_config: dict[str, SimulationConfig] | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Predict the properties in a dataset using molecular simulation, or by reweighting previous simulation data. @@ -661,15 +716,25 @@ def predict( per_type_scales: The scale factor to apply to each data type. A default of 1.0 will be used for any data type not specified. verbose: Whether to log additional information. + simulation_config: The (optional) simulation configuration, should contain + a config for each phase if not provided the default will be used. """ entries: list[DataEntry] = [*descent.utils.dataset.iter_dataset(dataset)] - required_simulations, entry_to_simulation = _plan_simulations(entries, topologies) + required_simulations, entry_to_simulation = _plan_simulations( + entries, topologies, simulation_config + ) observables = { phase: { key: _compute_observables( - phase, key, system, force_field, output_dir, cached_dir + phase, + key, + system, + force_field, + output_dir, + cached_dir, + simulation_config, ) for key, system in systems.items() } @@ -736,6 +801,7 @@ def default_closure( dataset: datasets.Dataset, per_type_scales: dict[DataType, float] | None = None, verbose: bool = False, + simulation_config: dict[str, SimulationConfig] | None = None, ) -> descent.optim.ClosureFn: """Return a default closure function for training against thermodynamic properties. @@ -747,6 +813,8 @@ def default_closure( dataset: The dataset to train against. per_type_scales: The scale factor to apply to each data type. verbose: Whether to log additional information about predictions. + simulation_config: The (optional) simulation configuration, should contain + a config for each phase if not provided the default will be used. Returns: The default closure function. @@ -767,6 +835,7 @@ def closure_fn( None, per_type_scales, verbose, + simulation_config, ) loss, gradient, hessian = ((y_pred - y_ref) ** 2).sum(), None, None diff --git a/descent/tests/targets/test_thermo.py b/descent/tests/targets/test_thermo.py index a32f199..8c9ee33 100644 --- a/descent/tests/targets/test_thermo.py +++ b/descent/tests/targets/test_thermo.py @@ -8,6 +8,7 @@ import descent.utils.dataset from descent.targets.thermo import ( DataEntry, + SimulationConfig, SimulationKey, _compute_observables, _convert_entry_to_system, @@ -21,6 +22,7 @@ default_config, extract_smiles, predict, + select_config, ) @@ -178,8 +180,60 @@ def test_default_config(phase, pressure, expected_n_mols): ) +def test_select_config(): + custom_config = { + "bulk": SimulationConfig( + max_mols=1000, + gen_coords=smee.mm.GenerateCoordsConfig(), + equilibrate=[ + smee.mm.MinimizationConfig(), + # short NVT equilibration simulation + smee.mm.SimulationConfig( + temperature=300 * openmm.unit.kelvin, + pressure=None, + n_steps=50000, + timestep=2.0 * openmm.unit.femtosecond, + ), + smee.mm.SimulationConfig( + temperature=300 * openmm.unit.kelvin, + pressure=1 * openmm.unit.atmosphere, + n_steps=100000, + timestep=2.0 * openmm.unit.femtosecond, + ), + ], + production=smee.mm.SimulationConfig( + temperature=300 * openmm.unit.kelvin, + pressure=1 * openmm.unit.atmosphere, + n_steps=1000000, + timestep=2.0 * openmm.unit.femtosecond, + ), + production_frequency=2000, + ) + } + temperature = 298.15 * openmm.unit.kelvin + pressure = 1 * openmm.unit.atmosphere + config = select_config( + phase="bulk", + temperature=temperature.value_in_unit(openmm.unit.kelvin), + pressure=pressure.value_in_unit(openmm.unit.atmosphere), + custom_config=custom_config, + ) + # make sure the custom config has been changed to match what was requested + for stage in config.equilibrate: + if isinstance(stage, smee.mm.SimulationConfig): + assert stage.temperature == temperature + assert stage.pressure == pressure + + assert config.production.temperature == temperature + assert config.production.pressure == pressure + assert config.max_mols == 1000 + + +@pytest.mark.parametrize( + "max_mols", [pytest.param(256, id="256"), pytest.param(1000, id="1000")] +) def test_plan_simulations( - mock_density_pure, mock_density_binary, mock_hvap, mock_hmix, mocker + mock_density_pure, mock_density_binary, mock_hvap, mock_hmix, mocker, max_mols ): topology_co = mocker.Mock() topology_cco = mocker.Mock() @@ -187,8 +241,26 @@ def test_plan_simulations( topologies = {"CO": topology_co, "CCO": topology_cco, "CCCC": topology_cccc} + # some mock config + custom_config = { + "bulk": SimulationConfig( + max_mols=max_mols, + gen_coords=smee.mm.GenerateCoordsConfig(), + equilibrate=[smee.mm.MinimizationConfig()], + production=smee.mm.SimulationConfig( + temperature=300 * openmm.unit.kelvin, + pressure=None, + n_steps=5000, + timestep=1.0 * openmm.unit.femtosecond, + ), + production_frequency=1000, + ) + } + required_simulations, entry_to_simulation = _plan_simulations( - [mock_density_pure, mock_density_binary, mock_hvap, mock_hmix], topologies + [mock_density_pure, mock_density_binary, mock_hvap, mock_hmix], + topologies, + custom_config, ) assert sorted(required_simulations) == ["bulk", "vacuum"] @@ -202,25 +274,25 @@ def test_plan_simulations( expected_cccc_key = SimulationKey( ("CCCC",), - (256,), + (max_mols,), mock_hvap["temperature"], mock_hvap["pressure"], ) expected_co_key = SimulationKey( ("CO",), - (256,), + (max_mols,), mock_density_pure["temperature"], mock_density_pure["pressure"], ) expected_cco_key = SimulationKey( ("CCO",), - (256,), + (max_mols,), mock_density_binary["temperature"], mock_density_binary["pressure"], ) expected_cco_co_key = SimulationKey( ("CCO", "CO"), - (128, 128), + (max_mols / 2, max_mols / 2), mock_density_binary["temperature"], mock_density_binary["pressure"], ) @@ -234,16 +306,19 @@ def test_plan_simulations( assert sorted(required_simulations["bulk"]) == sorted(expected_bulk_keys) - assert required_simulations["bulk"][expected_cccc_key].n_copies == [256] + assert required_simulations["bulk"][expected_cccc_key].n_copies == [max_mols] assert required_simulations["bulk"][expected_cccc_key].topologies == [topology_cccc] - assert required_simulations["bulk"][expected_cco_key].n_copies == [256] + assert required_simulations["bulk"][expected_cco_key].n_copies == [max_mols] assert required_simulations["bulk"][expected_cco_key].topologies == [topology_cco] - assert required_simulations["bulk"][expected_co_key].n_copies == [256] + assert required_simulations["bulk"][expected_co_key].n_copies == [max_mols] assert required_simulations["bulk"][expected_co_key].topologies == [topology_co] - assert required_simulations["bulk"][expected_cco_co_key].n_copies == [128, 128] + assert required_simulations["bulk"][expected_cco_co_key].n_copies == [ + max_mols / 2, + max_mols / 2, + ] assert required_simulations["bulk"][expected_cco_co_key].topologies == [ topology_cco, topology_co, @@ -334,7 +409,9 @@ def test_compute_observables_reweighted(tmp_cwd, mocker): expected_path = cached_dir / f"{phase}-{expected_hash}-frames.msgpack" expected_path.touch() - result = _compute_observables(phase, key, mock_system, mock_ff, tmp_cwd, cached_dir) + result = _compute_observables( + phase, key, mock_system, mock_ff, tmp_cwd, cached_dir, None + ) assert result.mean == mock_result assert {*result.std} == {*result.mean} @@ -375,7 +452,9 @@ def test_compute_observables_simulated(tmp_cwd, mocker): expected_path = tmp_cwd / f"{phase}-{expected_hash}-frames.msgpack" expected_path.touch() - result = _compute_observables(phase, key, mock_system, mock_ff, tmp_cwd, cached_dir) + result = _compute_observables( + phase, key, mock_system, mock_ff, tmp_cwd, cached_dir, None + ) assert result == mock_result mock_simulate.assert_called_once_with( @@ -540,6 +619,7 @@ def test_predict(tmp_cwd, mock_density_pure, mocker): mock_ff, tmp_cwd, None, + None, ) expected_y_ref = torch.tensor([mock_density_pure["value"] * mock_scale])