From 438e5d23c177b050233a644d643fa11141e46da0 Mon Sep 17 00:00:00 2001 From: Josh Horton Date: Fri, 30 Aug 2024 10:18:49 +0100 Subject: [PATCH] include vsite fitting in the trainer --- descent/tests/test_train.py | 159 ++++++++++++++++++++++++++++++++++++ descent/train.py | 114 +++++++++++++++++++++++--- 2 files changed, 260 insertions(+), 13 deletions(-) diff --git a/descent/tests/test_train.py b/descent/tests/test_train.py index 3884a55..d2e39ee 100644 --- a/descent/tests/test_train.py +++ b/descent/tests/test_train.py @@ -7,6 +7,7 @@ import pytest import smee import smee.converters +import smee.utils import torch from descent.train import AttributeConfig, ParameterConfig, Trainable, _PotentialKey @@ -46,6 +47,45 @@ def mock_ff() -> smee.TensorForceField: return ff +@pytest.fixture() +def water_sites_ff(): + interachange = openff.interchange.Interchange.from_smirnoff( + openff.toolkit.ForceField("tip4p_fb.offxml"), + openff.toolkit.Molecule.from_smiles("O").to_topology(), + ) + ff, _ = smee.converters.convert_interchange(interachange) + # make sure we have vsites in the force field + assert ff.v_sites is not None + # this is awkward to specify in the yaml config file can we make it easier? + expected_ids = ["[#1:2]-[#8X2H2+0:1]-[#1:3] EP once"] + vsite_ids = [key.id for key in ff.v_sites.keys] + assert vsite_ids == expected_ids + + return ff + + +@pytest.fixture() +def mock_vsite_configs(water_sites_ff): + return ParameterConfig( + cols=["distance"], + scales={"distance": 10.0}, + limits={"distance": (-1.0, -0.01)}, + include=[water_sites_ff.v_sites.keys[0]], + ) + + +@pytest.fixture() +def mock_water_parameter_config(water_sites_ff): + return { + "vdW": ParameterConfig( + cols=["epsilon", "sigma"], + scales={"epsilon": 10.0, "sigma": 1.0}, + limits={"epsilon": (0.0, None), "sigma": (0.0, None)}, + include=[water_sites_ff.potentials_by_type["vdW"].parameter_keys[0]], + ), + } + + @pytest.fixture() def mock_parameter_configs(mock_ff): return { @@ -268,3 +308,122 @@ def test_clamp(self, mock_ff, mock_parameter_configs, mock_attribute_configs): assert values.shape == expected_values.shape assert torch.allclose(values, expected_values) + + def test_init_vsites( + self, water_sites_ff, mock_vsite_configs, mock_water_parameter_config + ): + trainable = Trainable( + water_sites_ff, + parameters=mock_water_parameter_config, + attributes={}, + vsites=mock_vsite_configs, + ) + + assert trainable._param_types == ["vdW"] + # check we have a vdW parameter for the oxygen, hydrogen and vsite + assert trainable._param_shapes == [(3, 2)] + assert trainable._attr_types == [] + + assert trainable._values.shape == (9,) + assert torch.allclose( + trainable._values, + torch.cat( + [ + water_sites_ff.potentials_by_type["vdW"].parameters.flatten(), + water_sites_ff.v_sites.parameters.flatten(), + ] + ), + ) + + # check frozen parameters + # vdW params: eps, sig, eps, sig where only first smirks is unfrozen + # vsite params: dist, inplane, outplane where first smirks is unfrozen + expected_unfrozen_ids = torch.tensor([0, 1, 6]) + assert (trainable._unfrozen_idxs == expected_unfrozen_ids).all() + + assert torch.allclose( + trainable._clamp_lower, + torch.tensor([0.0, 0.0, -1.0], dtype=torch.float64), + ) + assert torch.allclose( + trainable._clamp_upper, + torch.tensor([torch.inf, torch.inf, -0.01], dtype=torch.float64), + ) + assert torch.allclose( + trainable._scales, + torch.tensor([10.0, 1.0, 10.0], dtype=torch.float64), + ) + + def test_to_values_vsites( + self, water_sites_ff, mock_vsite_configs, mock_water_parameter_config + ): + trainable = Trainable( + water_sites_ff, + parameters=mock_water_parameter_config, + attributes={}, + vsites=mock_vsite_configs, + ) + vdw_params = water_sites_ff.potentials_by_type["vdW"].parameters.flatten() + vsite_params = water_sites_ff.v_sites.parameters.flatten() + + expected_values = torch.tensor( + [ + vdw_params[0] * 10, # scale eps + vdw_params[1], # sigma no scale + vsite_params[0] * 10, # scale vsite distance + ] + ) + values = trainable.to_values() + + assert values.shape == expected_values.shape + assert torch.allclose(values, expected_values) + + def test_to_force_field_vsites_no_op( + self, water_sites_ff, mock_vsite_configs, mock_water_parameter_config + ): + ff_initial = copy.deepcopy(water_sites_ff) + + trainable = Trainable( + water_sites_ff, + parameters=mock_water_parameter_config, + attributes={}, + vsites=mock_vsite_configs, + ) + + ff = trainable.to_force_field(trainable.to_values()) + assert ( + ff.potentials_by_type["vdW"].parameters.shape + == ff_initial.potentials_by_type["vdW"].parameters.shape + ) + assert torch.allclose( + ff.potentials_by_type["vdW"].parameters, + ff_initial.potentials_by_type["vdW"].parameters, + ) + # vsite parameters are not float64 in the initial ff + vsite_initial = smee.utils.tensor_like( + ff_initial.v_sites.parameters, ff.v_sites.parameters + ) + + assert torch.allclose( + ff.v_sites.parameters, + vsite_initial, + ) + + def test_to_force_field_clamp_vsites( + self, water_sites_ff, mock_vsite_configs, mock_water_parameter_config + ): + trainable = Trainable( + water_sites_ff, + parameters=mock_water_parameter_config, + attributes={}, + vsites=mock_vsite_configs, + ) + + values = trainable.to_values().detach() + # set the distance to outside the clamp region + values[-1] = 0.0 + ff = trainable.to_force_field(values) + assert torch.allclose( + ff.v_sites.parameters[0], + torch.tensor([-0.0100, 3.1416, 0.0000], dtype=torch.float64), + ) diff --git a/descent/train.py b/descent/train.py index 0c7d938..142b52d 100644 --- a/descent/train.py +++ b/descent/train.py @@ -251,11 +251,68 @@ def _prepare( smee.utils.tensor_like(clamp_upper, values), ) + def _prepare_vsites( + self, force_field: smee.TensorForceField, config: ParameterConfig + ): + """ + Prepare the vsite parameters for optimisation. + + Args: + force_field: The tensor force field with parameters + which should be optimised. + config: The config of the parameters to train. + + Returns: + + """ + vsite_parameters = force_field.v_sites.parameters.detach().clone() + n_rows = vsite_parameters.shape[0] + vsite_parameters_flat = vsite_parameters.flatten() + # define the cols as they are not on the tensor model + vsite_cols = ["distance", "inPlaneAngle", "outOfPlaneAngle"] + + all_keys = [_PotentialKey(**key.dict()) for key in force_field.v_sites.keys] + excluded_keys = config.exclude or [] + unfrozen_keys = config.include or all_keys + + key_to_row = {key: row_idx for row_idx, key in enumerate(all_keys)} + assert len(key_to_row) == len(all_keys), "duplicate keys found" + + unfrozen_rows = { + key_to_row[key] for key in unfrozen_keys if key not in excluded_keys + } + + unfrozen_idxs = [ + col_idx + row_idx * vsite_parameters.shape[1] + for row_idx in range(n_rows) + if row_idx in unfrozen_rows + # the vsite model has no parameter cols so define here + for col_idx, col in enumerate(vsite_cols) + if col in config.cols + ] + vsite_scales = [config.scales.get(col, 1.0) for col in vsite_cols] * n_rows + clamp_lower = [ + config.limits.get(col, (None, None))[0] for col in vsite_cols + ] * n_rows + clamp_lower = [-torch.inf if x is None else x for x in clamp_lower] + clamp_upper = [ + config.limits.get(col, (None, None))[1] for col in vsite_cols + ] * n_rows + clamp_upper = [torch.inf if x is None else x for x in clamp_upper] + return ( + vsite_parameters_flat, + torch.tensor(unfrozen_idxs), + smee.utils.tensor_like(vsite_scales, vsite_parameters), + smee.utils.tensor_like(clamp_lower, vsite_parameters), + smee.utils.tensor_like(clamp_upper, vsite_parameters), + ) + def __init__( self, force_field: smee.TensorForceField, parameters: dict[str, ParameterConfig], attributes: dict[str, AttributeConfig], + vsites: ParameterConfig | None = None, ): """ @@ -263,8 +320,10 @@ def __init__( force_field: The force field to wrap. parameters: Configure which parameters to train. attributes: Configure which attributes to train. + vsites: Configure which vsite parameters to train. """ self._force_field = force_field + self._fit_vsites = False ( self._param_types, @@ -285,20 +344,38 @@ def __init__( attr_clamp_upper, ) = self._prepare(force_field, attributes, "attributes") - self._values = torch.cat([param_values, attr_values]) + values = [param_values, attr_values] + unfrozen_idxs = [param_unfrozen_idxs, attr_unfrozen_idxs + len(param_scales)] + scales = [param_scales, attr_scales] + clamp_lower = [param_clamp_lower, attr_clamp_lower] + clamp_upper = [param_clamp_upper, attr_clamp_upper] + + if vsites is not None: + ( + vsite_values, + vsite_unfrozen_idxs, + vsite_scales, + vsite_clamp_lower, + vsite_clamp_upper, + ) = self._prepare_vsites(force_field, vsites) + self._fit_vsites = True + + values.append(vsite_values) + unfrozen_idxs.append( + (vsite_unfrozen_idxs + len(param_scales) + len(attr_scales)) + ) + scales.append(vsite_scales) + clamp_lower.append(vsite_clamp_lower) + clamp_upper.append(vsite_clamp_upper) - self._unfrozen_idxs = torch.cat( - [param_unfrozen_idxs, attr_unfrozen_idxs + len(param_scales)] - ).long() + self._values = torch.cat(values) - self._scales = torch.cat([param_scales, attr_scales])[self._unfrozen_idxs] + self._unfrozen_idxs = torch.cat(unfrozen_idxs).long() - self._clamp_lower = torch.cat([param_clamp_lower, attr_clamp_lower])[ - self._unfrozen_idxs - ] - self._clamp_upper = torch.cat([param_clamp_upper, attr_clamp_upper])[ - self._unfrozen_idxs - ] + self._scales = torch.cat(scales)[self._unfrozen_idxs] + + self._clamp_lower = torch.cat(clamp_lower)[self._unfrozen_idxs] + self._clamp_upper = torch.cat(clamp_upper)[self._unfrozen_idxs] @torch.no_grad() def to_values(self) -> torch.Tensor: @@ -320,18 +397,29 @@ def to_force_field(self, values_flat: torch.Tensor) -> smee.TensorForceField: values[self._unfrozen_idxs] = (values_flat / self._scales).clamp( min=self._clamp_lower, max=self._clamp_upper ) - values = _unflatten_tensors(values, self._param_shapes + self._attr_shapes) + shapes = self._param_shapes + self._attr_shapes + + if self._fit_vsites: + shapes.append(self._force_field.v_sites.parameters.shape) + + values = _unflatten_tensors(values, shapes) params = values[: len(self._param_shapes)] for potential_type, param in zip(self._param_types, params, strict=True): potentials[potential_type].parameters = param - attrs = values[len(self._param_shapes) :] + attrs = values[ + len(self._param_shapes) : len(self._param_shapes) + len(self._attr_shapes) + ] for potential_type, attr in zip(self._attr_types, attrs, strict=True): potentials[potential_type].attributes = attr + if self._fit_vsites: + vsite_params = values[len(self._param_shapes) + len(self._attr_shapes) :] + self._force_field.v_sites.parameters = vsite_params[0] + return self._force_field @torch.no_grad()