diff --git a/descent/tests/test_train.py b/descent/tests/test_train.py index 1b445ee..1aae364 100644 --- a/descent/tests/test_train.py +++ b/descent/tests/test_train.py @@ -7,6 +7,7 @@ import pytest import smee import smee.converters +import smee.utils import torch from descent.train import AttributeConfig, ParameterConfig, Trainable, _PotentialKey @@ -46,6 +47,45 @@ def mock_ff() -> smee.TensorForceField: return ff +@pytest.fixture() +def water_sites_ff(): + interachange = openff.interchange.Interchange.from_smirnoff( + openff.toolkit.ForceField("tip4p_fb.offxml"), + openff.toolkit.Molecule.from_smiles("O").to_topology(), + ) + ff, _ = smee.converters.convert_interchange(interachange) + # make sure we have vsites in the force field + assert ff.v_sites is not None + # this id is awkward to specify in the yaml config file; can we make it easier? + expected_ids = ["[#1:2]-[#8X2H2+0:1]-[#1:3] EP once"] + vsite_ids = [key.id for key in ff.v_sites.keys] + assert vsite_ids == expected_ids + + return ff + + +@pytest.fixture() +def mock_vsite_configs(water_sites_ff): + return ParameterConfig( + cols=["distance"], + scales={"distance": 10.0}, + limits={"distance": (-1.0, -0.01)}, + include=[water_sites_ff.v_sites.keys[0]], + ) + + +@pytest.fixture() +def mock_water_parameter_config(water_sites_ff): + return { + "vdW": ParameterConfig( + cols=["epsilon", "sigma"], + scales={"epsilon": 10.0, "sigma": 1.0}, + limits={"epsilon": (0.0, None), "sigma": (0.0, None)}, + include=[water_sites_ff.potentials_by_type["vdW"].parameter_keys[0]], + ), + } + + @pytest.fixture() def mock_parameter_configs(mock_ff): return { @@ -286,148 +326,121 @@ def test_clamp(self, mock_ff, mock_parameter_configs, mock_attribute_configs): assert values.shape == expected_values.shape assert torch.allclose(values, expected_values) - def test_regularized_idxs_no_regularization( - self, mock_ff, mock_parameter_configs, mock_attribute_configs + def test_init_vsites( + self, water_sites_ff, mock_vsite_configs, 
mock_water_parameter_config ): trainable = Trainable( - mock_ff, - parameters=mock_parameter_configs, - attributes=mock_attribute_configs, + water_sites_ff, + parameters=mock_water_parameter_config, + attributes={}, + vsites=mock_vsite_configs, ) - assert len(trainable.regularized_idxs) == 0 - assert len(trainable.regularization_weights) == 0 + assert trainable._param_types == ["vdW"] + # check we have a vdW parameter for the oxygen, hydrogen and vsite + assert trainable._param_shapes == [(3, 2)] + assert trainable._attr_types == [] - def test_regularized_idxs_with_parameter_regularization(self, mock_ff): - parameter_configs = { - "vdW": ParameterConfig( - cols=["epsilon", "sigma"], - regularize={"epsilon": 0.01, "sigma": 0.001}, + assert trainable._values.shape == (9,) + assert torch.allclose( + trainable._values, + torch.cat( + [ + water_sites_ff.potentials_by_type["vdW"].parameters.flatten(), + water_sites_ff.v_sites.parameters.flatten(), + ] ), - } - attribute_configs = {} - - trainable = Trainable( - mock_ff, - parameters=parameter_configs, - attributes=attribute_configs, ) - # vdW has 2 parameters (2 rows), and we're regularizing both epsilon and sigma - # So we should have 4 regularized values total: 2 epsilons + 2 sigmas - expected_idxs = torch.tensor([0, 1, 2, 3], dtype=torch.long) - assert torch.equal(trainable.regularized_idxs, expected_idxs) + # check which parameters are unfrozen + # vdW params: (eps, sig) x 3 rows where only the first smirks is unfrozen + # vsite params: dist, inplane, outplane where only dist is unfrozen + expected_unfrozen_ids = torch.tensor([0, 1, 6]) + assert (trainable._unfrozen_idxs == expected_unfrozen_ids).all() - # Check the weights match what we configured - # Interleaved: row 0 (eps, sig), row 1 (eps, sig) - expected_weights = torch.tensor( - [0.01, 0.001, 0.01, 0.001], dtype=trainable.regularization_weights.dtype + assert torch.allclose( + trainable._clamp_lower, + torch.tensor([0.0, 0.0, -1.0], dtype=torch.float64), ) - assert 
torch.allclose(trainable.regularization_weights, expected_weights) - - def test_regularized_idxs_with_attribute_regularization(self, mock_ff): - parameter_configs = {} - attribute_configs = { - "vdW": AttributeConfig( - cols=["scale_14", "scale_15"], - regularize={"scale_14": 0.05}, - ) - } - - trainable = Trainable( - mock_ff, - parameters=parameter_configs, - attributes=attribute_configs, + assert torch.allclose( + trainable._clamp_upper, + torch.tensor([torch.inf, torch.inf, -0.01], dtype=torch.float64), ) - - # Only scale_14 should be regularized (1 attribute) - expected_idxs = torch.tensor([0], dtype=torch.long) - assert torch.equal(trainable.regularized_idxs, expected_idxs) - - expected_weights = torch.tensor( - [0.05], dtype=trainable.regularization_weights.dtype + assert torch.allclose( + trainable._scales, + torch.tensor([10.0, 1.0, 10.0], dtype=torch.float64), ) - assert torch.allclose(trainable.regularization_weights, expected_weights) - - def test_regularized_idxs_with_mixed_regularization(self, mock_ff): - parameter_configs = { - "vdW": ParameterConfig( - cols=["epsilon", "sigma"], - regularize={"epsilon": 0.02}, - include=[mock_ff.potentials_by_type["vdW"].parameter_keys[0]], - ), - } - attribute_configs = { - "vdW": AttributeConfig( - cols=["scale_14"], - regularize={"scale_14": 0.1}, - ) - } + def test_to_values_vsites( + self, water_sites_ff, mock_vsite_configs, mock_water_parameter_config + ): trainable = Trainable( - mock_ff, - parameters=parameter_configs, - attributes=attribute_configs, + water_sites_ff, + parameters=mock_water_parameter_config, + attributes={}, + vsites=mock_vsite_configs, ) + vdw_params = water_sites_ff.potentials_by_type["vdW"].parameters.flatten() + vsite_params = water_sites_ff.v_sites.parameters.flatten() - # Only first vdW parameter row is included, with only epsilon regularized - # Plus scale_14 attribute - expected_idxs = torch.tensor([0, 2], dtype=torch.long) - assert torch.equal(trainable.regularized_idxs, 
expected_idxs) - - # First should be epsilon (0.02), second should be scale_14 (0.1) - expected_weights = torch.tensor( - [0.02, 0.1], dtype=trainable.regularization_weights.dtype + expected_values = torch.tensor( + [ + vdw_params[0] * 10, # scale eps + vdw_params[1], # sigma no scale + vsite_params[0] * 10, # scale vsite distance + ] ) - assert torch.allclose(trainable.regularization_weights, expected_weights) + values = trainable.to_values() - def test_regularized_idxs_excluded_parameters(self, mock_ff): - parameter_configs = { - "Bonds": ParameterConfig( - cols=["k", "length"], - regularize={"k": 0.01, "length": 0.02}, - exclude=[mock_ff.potentials_by_type["Bonds"].parameter_keys[0]], - ), - } - attribute_configs = {} + assert values.shape == expected_values.shape + assert torch.allclose(values, expected_values) + + def test_to_force_field_vsites_no_op( + self, water_sites_ff, mock_vsite_configs, mock_water_parameter_config + ): + ff_initial = copy.deepcopy(water_sites_ff) trainable = Trainable( - mock_ff, - parameters=parameter_configs, - attributes=attribute_configs, + water_sites_ff, + parameters=mock_water_parameter_config, + attributes={}, + vsites=mock_vsite_configs, ) - # Only second bond parameter row should be included (first is excluded) - # Both k and length are regularized - expected_idxs = torch.tensor([0, 1], dtype=torch.long) - assert torch.equal(trainable.regularized_idxs, expected_idxs) - - expected_weights = torch.tensor( - [0.01, 0.02], dtype=trainable.regularization_weights.dtype + ff = trainable.to_force_field(trainable.to_values()) + assert ( + ff.potentials_by_type["vdW"].parameters.shape + == ff_initial.potentials_by_type["vdW"].parameters.shape + ) + assert torch.allclose( + ff.potentials_by_type["vdW"].parameters, + ff_initial.potentials_by_type["vdW"].parameters, + ) + # vsite parameters are not float64 in the initial ff + vsite_initial = smee.utils.tensor_like( + ff_initial.v_sites.parameters, ff.v_sites.parameters ) - assert 
torch.allclose(trainable.regularization_weights, expected_weights) - def test_regularization_indices_match_unfrozen_values(self, mock_ff): - parameter_configs = { - "vdW": ParameterConfig( - cols=["epsilon"], - regularize={"epsilon": 0.01}, - ), - } - attribute_configs = {} + assert torch.allclose( + ff.v_sites.parameters, + vsite_initial, + ) + def test_to_force_field_clamp_vsites( + self, water_sites_ff, mock_vsite_configs, mock_water_parameter_config + ): trainable = Trainable( - mock_ff, - parameters=parameter_configs, - attributes=attribute_configs, + water_sites_ff, + parameters=mock_water_parameter_config, + attributes={}, + vsites=mock_vsite_configs, ) - values = trainable.to_values() - - # Regularization indices should be valid indices into the unfrozen values - assert trainable.regularized_idxs.max() < len(values) - assert trainable.regularized_idxs.min() >= 0 - - # We should be able to index the values tensor with regularization indices - regularized_values = values[trainable.regularized_idxs] - assert len(regularized_values) == len(trainable.regularized_idxs) + values = trainable.to_values().detach() + # set the distance to outside the clamp region + values[-1] = 0.0 + ff = trainable.to_force_field(values) + assert torch.allclose( + ff.v_sites.parameters[0], + torch.tensor([-0.0100, 3.1416, 0.0000], dtype=torch.float64), + ) diff --git a/descent/train.py b/descent/train.py index 1dcb40f..4a2ee59 100644 --- a/descent/train.py +++ b/descent/train.py @@ -326,11 +326,68 @@ def _prepare( ), ) + def _prepare_vsites( + self, force_field: smee.TensorForceField, config: ParameterConfig + ): + """ + Prepare the vsite parameters for optimisation. + + Args: + force_field: The tensor force field with parameters + which should be optimised. + config: The config of the parameters to train. 
+ + Returns: + The flat vsite values, unfrozen indices, scales, and lower / upper clamps. + """ + vsite_parameters = force_field.v_sites.parameters.detach().clone() + n_rows = vsite_parameters.shape[0] + vsite_parameters_flat = vsite_parameters.flatten() + # define the cols as they are not on the tensor model + vsite_cols = ["distance", "inPlaneAngle", "outOfPlaneAngle"] + + all_keys = [_PotentialKey(**key.dict()) for key in force_field.v_sites.keys] + excluded_keys = config.exclude or [] + unfrozen_keys = config.include or all_keys + + key_to_row = {key: row_idx for row_idx, key in enumerate(all_keys)} + assert len(key_to_row) == len(all_keys), "duplicate keys found" + + unfrozen_rows = { + key_to_row[key] for key in unfrozen_keys if key not in excluded_keys + } + + unfrozen_idxs = [ + col_idx + row_idx * vsite_parameters.shape[1] + for row_idx in range(n_rows) + if row_idx in unfrozen_rows + # the vsite model has no parameter cols so define here + for col_idx, col in enumerate(vsite_cols) + if col in config.cols + ] + vsite_scales = [config.scales.get(col, 1.0) for col in vsite_cols] * n_rows + clamp_lower = [ + config.limits.get(col, (None, None))[0] for col in vsite_cols + ] * n_rows + clamp_lower = [-torch.inf if x is None else x for x in clamp_lower] + clamp_upper = [ + config.limits.get(col, (None, None))[1] for col in vsite_cols + ] * n_rows + clamp_upper = [torch.inf if x is None else x for x in clamp_upper] + return ( + vsite_parameters_flat, + torch.tensor(unfrozen_idxs), + smee.utils.tensor_like(vsite_scales, vsite_parameters), + smee.utils.tensor_like(clamp_lower, vsite_parameters), + smee.utils.tensor_like(clamp_upper, vsite_parameters), + ) + def __init__( self, force_field: smee.TensorForceField, parameters: dict[str, ParameterConfig], attributes: dict[str, AttributeConfig], + vsites: ParameterConfig | None = None, ): """ @@ -338,8 +395,10 @@ def __init__( force_field: The force field to wrap. parameters: Configure which parameters to train. attributes: Configure which attributes to train. 
+ vsites: Configure which vsite parameters to train. """ self._force_field = force_field + self._fit_vsites = False ( self._param_types, @@ -364,20 +423,38 @@ def __init__( attr_regularization_weights, ) = self._prepare(force_field, attributes, "attributes") - self._values = torch.cat([param_values, attr_values]) + values = [param_values, attr_values] + unfrozen_idxs = [param_unfrozen_idxs, attr_unfrozen_idxs + len(param_scales)] + scales = [param_scales, attr_scales] + clamp_lower = [param_clamp_lower, attr_clamp_lower] + clamp_upper = [param_clamp_upper, attr_clamp_upper] - self._unfrozen_idxs = torch.cat( - [param_unfrozen_idxs, attr_unfrozen_idxs + len(param_scales)] - ).long() + if vsites is not None: + ( + vsite_values, + vsite_unfrozen_idxs, + vsite_scales, + vsite_clamp_lower, + vsite_clamp_upper, + ) = self._prepare_vsites(force_field, vsites) + self._fit_vsites = True + + values.append(vsite_values) + unfrozen_idxs.append( + (vsite_unfrozen_idxs + len(param_scales) + len(attr_scales)) + ) + scales.append(vsite_scales) + clamp_lower.append(vsite_clamp_lower) + clamp_upper.append(vsite_clamp_upper) - self._scales = torch.cat([param_scales, attr_scales])[self._unfrozen_idxs] + self._values = torch.cat(values) - self._clamp_lower = torch.cat([param_clamp_lower, attr_clamp_lower])[ - self._unfrozen_idxs - ] - self._clamp_upper = torch.cat([param_clamp_upper, attr_clamp_upper])[ - self._unfrozen_idxs - ] + self._unfrozen_idxs = torch.cat(unfrozen_idxs).long() + + self._scales = torch.cat(scales)[self._unfrozen_idxs] + + self._clamp_lower = torch.cat(clamp_lower)[self._unfrozen_idxs] + self._clamp_upper = torch.cat(clamp_upper)[self._unfrozen_idxs] # Store regularization information all_regularized_idxs = torch.cat( @@ -427,18 +504,29 @@ def to_force_field(self, values_flat: torch.Tensor) -> smee.TensorForceField: values[self._unfrozen_idxs] = (values_flat / self._scales).clamp( min=self._clamp_lower, max=self._clamp_upper ) - values = _unflatten_tensors(values, 
self._param_shapes + self._attr_shapes) + shapes = self._param_shapes + self._attr_shapes + + if self._fit_vsites: + shapes.append(self._force_field.v_sites.parameters.shape) + + values = _unflatten_tensors(values, shapes) params = values[: len(self._param_shapes)] for potential_type, param in zip(self._param_types, params, strict=True): potentials[potential_type].parameters = param - attrs = values[len(self._param_shapes) :] + attrs = values[ + len(self._param_shapes) : len(self._param_shapes) + len(self._attr_shapes) + ] for potential_type, attr in zip(self._attr_types, attrs, strict=True): potentials[potential_type].attributes = attr + if self._fit_vsites: + vsite_params = values[len(self._param_shapes) + len(self._attr_shapes) :] + self._force_field.v_sites.parameters = vsite_params[0] + return self._force_field @torch.no_grad()