-
Notifications
You must be signed in to change notification settings - Fork 6
Vsite training with regularization #89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
438e5d2
c441511
c116a12
58deeec
8246eca
9a7744c
e05bc9f
a869d83
396e92c
0bfb99f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,8 +7,10 @@ | |
| import pytest | ||
| import smee | ||
| import smee.converters | ||
| import smee.utils | ||
| import torch | ||
|
|
||
|
|
||
| from descent.train import AttributeConfig, ParameterConfig, Trainable, _PotentialKey | ||
|
|
||
|
|
||
|
|
@@ -46,6 +48,45 @@ def mock_ff() -> smee.TensorForceField: | |
| return ff | ||
|
|
||
|
|
||
| @pytest.fixture() | ||
| def water_sites_ff(): | ||
| interchange = openff.interchange.Interchange.from_smirnoff( | ||
| openff.toolkit.ForceField("tip4p_fb.offxml"), | ||
| openff.toolkit.Molecule.from_smiles("O").to_topology(), | ||
| ) | ||
| ff, _ = smee.converters.convert_interchange(interchange) | ||
| # make sure we have vsites in the force field | ||
| assert ff.v_sites is not None | ||
| # this is awkward to specify in the yaml config file can we make it easier? | ||
| expected_ids = ["[#1:2]-[#8X2H2+0:1]-[#1:3] EP once"] | ||
| vsite_ids = [key.id for key in ff.v_sites.keys] | ||
| assert vsite_ids == expected_ids | ||
|
|
||
| return ff | ||
|
|
||
|
|
||
| @pytest.fixture() | ||
| def mock_vsite_configs(water_sites_ff): | ||
| return ParameterConfig( | ||
| cols=["distance"], | ||
| scales={"distance": 10.0}, | ||
| limits={"distance": (-1.0, -0.01)}, | ||
| include=[water_sites_ff.v_sites.keys[0]], | ||
| ) | ||
|
|
||
|
|
||
| @pytest.fixture() | ||
| def mock_water_parameter_config(water_sites_ff): | ||
| return { | ||
| "vdW": ParameterConfig( | ||
| cols=["epsilon", "sigma"], | ||
| scales={"epsilon": 10.0, "sigma": 1.0}, | ||
| limits={"epsilon": (0.0, None), "sigma": (0.0, None)}, | ||
| include=[water_sites_ff.potentials_by_type["vdW"].parameter_keys[0]], | ||
| ), | ||
| } | ||
|
|
||
|
|
||
| @pytest.fixture() | ||
| def mock_parameter_configs(mock_ff): | ||
| return { | ||
|
|
@@ -286,148 +327,150 @@ def test_clamp(self, mock_ff, mock_parameter_configs, mock_attribute_configs): | |
| assert values.shape == expected_values.shape | ||
| assert torch.allclose(values, expected_values) | ||
|
|
||
| def test_regularized_idxs_no_regularization( | ||
| self, mock_ff, mock_parameter_configs, mock_attribute_configs | ||
| def test_init_vsites( | ||
| self, water_sites_ff, mock_vsite_configs, mock_water_parameter_config | ||
| ): | ||
| trainable = Trainable( | ||
| mock_ff, | ||
| parameters=mock_parameter_configs, | ||
| attributes=mock_attribute_configs, | ||
| water_sites_ff, | ||
| parameters=mock_water_parameter_config, | ||
| attributes={}, | ||
| vsites=mock_vsite_configs, | ||
| ) | ||
|
|
||
| assert len(trainable.regularized_idxs) == 0 | ||
| assert len(trainable.regularization_weights) == 0 | ||
| assert trainable._param_types == ["vdW"] | ||
| # check we have a vdW parameter for the oxygen, hydrogen and vsite | ||
| assert trainable._param_shapes == [(3, 2)] | ||
| assert trainable._attr_types == [] | ||
|
|
||
| def test_regularized_idxs_with_parameter_regularization(self, mock_ff): | ||
| parameter_configs = { | ||
| "vdW": ParameterConfig( | ||
| cols=["epsilon", "sigma"], | ||
| regularize={"epsilon": 0.01, "sigma": 0.001}, | ||
| assert trainable._values.shape == (9,) | ||
| assert torch.allclose( | ||
| trainable._values, | ||
| torch.cat( | ||
| [ | ||
| water_sites_ff.potentials_by_type["vdW"].parameters.flatten(), | ||
| water_sites_ff.v_sites.parameters.flatten(), | ||
| ] | ||
| ), | ||
| } | ||
| attribute_configs = {} | ||
|
|
||
| trainable = Trainable( | ||
| mock_ff, | ||
| parameters=parameter_configs, | ||
| attributes=attribute_configs, | ||
| ) | ||
|
|
||
| # vdW has 2 parameters (2 rows), and we're regularizing both epsilon and sigma | ||
| # So we should have 4 regularized values total: 2 epsilons + 2 sigmas | ||
| expected_idxs = torch.tensor([0, 1, 2, 3], dtype=torch.long) | ||
| assert torch.equal(trainable.regularized_idxs, expected_idxs) | ||
| # check frozen parameters | ||
| # vdW params: eps, sig, eps, sig where only first smirks is unfrozen | ||
| # vsite params: dist, inplane, outplane where first smirks is unfrozen | ||
| expected_unfrozen_ids = torch.tensor([0, 1, 6]) | ||
| assert (trainable._unfrozen_idxs == expected_unfrozen_ids).all() | ||
|
|
||
| # Check the weights match what we configured | ||
| # Interleaved: row 0 (eps, sig), row 1 (eps, sig) | ||
| expected_weights = torch.tensor( | ||
| [0.01, 0.001, 0.01, 0.001], dtype=trainable.regularization_weights.dtype | ||
| assert torch.allclose( | ||
| trainable._clamp_lower, | ||
| torch.tensor([0.0, 0.0, -1.0], dtype=torch.float64), | ||
| ) | ||
| assert torch.allclose( | ||
| trainable._clamp_upper, | ||
| torch.tensor([torch.inf, torch.inf, -0.01], dtype=torch.float64), | ||
| ) | ||
| assert torch.allclose( | ||
| trainable._scales, | ||
| torch.tensor([10.0, 1.0, 10.0], dtype=torch.float64), | ||
| ) | ||
| assert torch.allclose(trainable.regularization_weights, expected_weights) | ||
|
|
||
| def test_regularized_idxs_with_attribute_regularization(self, mock_ff): | ||
| parameter_configs = {} | ||
| attribute_configs = { | ||
| "vdW": AttributeConfig( | ||
| cols=["scale_14", "scale_15"], | ||
| regularize={"scale_14": 0.05}, | ||
| ) | ||
| } | ||
|
|
||
| def test_to_values_vsites( | ||
| self, water_sites_ff, mock_vsite_configs, mock_water_parameter_config | ||
| ): | ||
| trainable = Trainable( | ||
| mock_ff, | ||
| parameters=parameter_configs, | ||
| attributes=attribute_configs, | ||
| water_sites_ff, | ||
| parameters=mock_water_parameter_config, | ||
| attributes={}, | ||
| vsites=mock_vsite_configs, | ||
| ) | ||
| vdw_params = water_sites_ff.potentials_by_type["vdW"].parameters.flatten() | ||
| vsite_params = water_sites_ff.v_sites.parameters.flatten() | ||
|
|
||
| # Only scale_14 should be regularized (1 attribute) | ||
| expected_idxs = torch.tensor([0], dtype=torch.long) | ||
| assert torch.equal(trainable.regularized_idxs, expected_idxs) | ||
|
|
||
| expected_weights = torch.tensor( | ||
| [0.05], dtype=trainable.regularization_weights.dtype | ||
| expected_values = torch.tensor( | ||
| [ | ||
| vdw_params[0] * 10, # scale eps | ||
| vdw_params[1], # sigma no scale | ||
| vsite_params[0] * 10, # scale vsite distance | ||
| ] | ||
| ) | ||
| assert torch.allclose(trainable.regularization_weights, expected_weights) | ||
| values = trainable.to_values() | ||
|
|
||
| def test_regularized_idxs_with_mixed_regularization(self, mock_ff): | ||
| parameter_configs = { | ||
| "vdW": ParameterConfig( | ||
| cols=["epsilon", "sigma"], | ||
| regularize={"epsilon": 0.02}, | ||
| include=[mock_ff.potentials_by_type["vdW"].parameter_keys[0]], | ||
| ), | ||
| } | ||
| attribute_configs = { | ||
| "vdW": AttributeConfig( | ||
| cols=["scale_14"], | ||
| regularize={"scale_14": 0.1}, | ||
| ) | ||
| } | ||
| assert values.shape == expected_values.shape | ||
| assert torch.allclose(values, expected_values) | ||
|
|
||
| def test_to_force_field_vsites_no_op( | ||
| self, water_sites_ff, mock_vsite_configs, mock_water_parameter_config | ||
| ): | ||
| ff_initial = copy.deepcopy(water_sites_ff) | ||
|
|
||
| trainable = Trainable( | ||
| mock_ff, | ||
| parameters=parameter_configs, | ||
| attributes=attribute_configs, | ||
| water_sites_ff, | ||
| parameters=mock_water_parameter_config, | ||
| attributes={}, | ||
| vsites=mock_vsite_configs, | ||
| ) | ||
|
|
||
| # Only first vdW parameter row is included, with only epsilon regularized | ||
| # Plus scale_14 attribute | ||
| expected_idxs = torch.tensor([0, 2], dtype=torch.long) | ||
| assert torch.equal(trainable.regularized_idxs, expected_idxs) | ||
|
|
||
| # First should be epsilon (0.02), second should be scale_14 (0.1) | ||
| expected_weights = torch.tensor( | ||
| [0.02, 0.1], dtype=trainable.regularization_weights.dtype | ||
| ff = trainable.to_force_field(trainable.to_values()) | ||
| assert ( | ||
| ff.potentials_by_type["vdW"].parameters.shape | ||
| == ff_initial.potentials_by_type["vdW"].parameters.shape | ||
| ) | ||
| assert torch.allclose( | ||
| ff.potentials_by_type["vdW"].parameters, | ||
| ff_initial.potentials_by_type["vdW"].parameters, | ||
| ) | ||
| # vsite parameters are not float64 in the initial ff | ||
| vsite_initial = smee.utils.tensor_like( | ||
| ff_initial.v_sites.parameters, ff.v_sites.parameters | ||
| ) | ||
| assert torch.allclose(trainable.regularization_weights, expected_weights) | ||
|
|
||
| def test_regularized_idxs_excluded_parameters(self, mock_ff): | ||
| parameter_configs = { | ||
| "Bonds": ParameterConfig( | ||
| cols=["k", "length"], | ||
| regularize={"k": 0.01, "length": 0.02}, | ||
| exclude=[mock_ff.potentials_by_type["Bonds"].parameter_keys[0]], | ||
| ), | ||
| } | ||
| attribute_configs = {} | ||
| assert torch.allclose( | ||
| ff.v_sites.parameters, | ||
| vsite_initial, | ||
| ) | ||
|
|
||
| def test_to_force_field_clamp_vsites( | ||
| self, water_sites_ff, mock_vsite_configs, mock_water_parameter_config | ||
| ): | ||
| trainable = Trainable( | ||
| mock_ff, | ||
| parameters=parameter_configs, | ||
| attributes=attribute_configs, | ||
| water_sites_ff, | ||
| parameters=mock_water_parameter_config, | ||
| attributes={}, | ||
| vsites=mock_vsite_configs, | ||
| ) | ||
|
|
||
| # Only second bond parameter row should be included (first is excluded) | ||
| # Both k and length are regularized | ||
| expected_idxs = torch.tensor([0, 1], dtype=torch.long) | ||
| assert torch.equal(trainable.regularized_idxs, expected_idxs) | ||
| # The trainable values are, in order, the vdW parameters (eps, sigma) | ||
| # followed by the vsite distance. # When we set the last trainable | ||
| # value to 0.0, this corresponds to the vsite distance, which is the first | ||
| # parameter in ff.v_sites.parameters. | ||
|
|
||
| expected_weights = torch.tensor( | ||
| [0.01, 0.02], dtype=trainable.regularization_weights.dtype | ||
| values = trainable.to_values().detach() | ||
| # set the distance to outside the clamp region | ||
| values[-1] = 0.0 | ||
| ff = trainable.to_force_field(values) | ||
| assert torch.allclose( | ||
| ff.v_sites.parameters[0], | ||
| torch.tensor([-0.0100, 3.1416, 0.0000], dtype=torch.float64), | ||
|
Comment on lines
+445
to
+451
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm confused by this bit -- you set the distance of the last value to outside the clamp region, but then it's the first value that gets clamped to -0.01?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've added a comment to clarify. |
||
| ) | ||
| assert torch.allclose(trainable.regularization_weights, expected_weights) | ||
|
|
||
| def test_regularization_indices_match_unfrozen_values(self, mock_ff): | ||
| parameter_configs = { | ||
| "vdW": ParameterConfig( | ||
| cols=["epsilon"], | ||
| regularize={"epsilon": 0.01}, | ||
| ), | ||
| } | ||
| attribute_configs = {} | ||
| def test_init_vsites_regularization( | ||
| self, water_sites_ff, mock_water_parameter_config | ||
| ): | ||
| vsite_config = ParameterConfig( | ||
| cols=["distance"], | ||
| scales={"distance": 10.0}, | ||
| limits={"distance": (-1.0, -0.01)}, | ||
| regularize={"distance": 0.25}, | ||
| include=[water_sites_ff.v_sites.keys[0]], | ||
| ) | ||
|
|
||
| trainable = Trainable( | ||
| mock_ff, | ||
| parameters=parameter_configs, | ||
| attributes=attribute_configs, | ||
| water_sites_ff, | ||
| parameters=mock_water_parameter_config, | ||
| attributes={}, | ||
| vsites=vsite_config, | ||
| ) | ||
|
|
||
| values = trainable.to_values() | ||
|
|
||
| # Regularization indices should be valid indices into the unfrozen values | ||
| assert trainable.regularized_idxs.max() < len(values) | ||
| assert trainable.regularized_idxs.min() >= 0 | ||
|
|
||
| # We should be able to index the values tensor with regularization indices | ||
| regularized_values = values[trainable.regularized_idxs] | ||
| assert len(regularized_values) == len(trainable.regularized_idxs) | ||
| assert torch.equal(trainable.regularized_idxs, torch.tensor([2])) | ||
| assert torch.allclose( | ||
| trainable.regularization_weights, | ||
| torch.tensor([0.25], dtype=torch.float64), | ||
| ) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This id looks weird indeed. Can't it be simplify specified using
""[#1:2]-[#8X2H2+0:1]-[#1:3] EP"?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"once" is added here https://github.com/openforcefield/openff-interchange/blob/c2f82bb4d4beceef80deda257155bec5c5b038a0/openff/interchange/smirnoff/_virtual_sites.py#L122 (as it matches once) so the ID ends up being as displayed. I can't think of a way to get round this without making the expected id less precise, but let me know if you have any thoughts!