Skip to content
267 changes: 155 additions & 112 deletions descent/tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
import pytest
import smee
import smee.converters
import smee.utils
import torch


from descent.train import AttributeConfig, ParameterConfig, Trainable, _PotentialKey


Expand Down Expand Up @@ -46,6 +48,45 @@ def mock_ff() -> smee.TensorForceField:
return ff


@pytest.fixture()
def water_sites_ff():
    """Build a TIP4P-FB water force field that contains a virtual site.

    Returns:
        The converted ``smee.TensorForceField`` with exactly one v-site
        (the water "EP" extra point).
    """
    interchange = openff.interchange.Interchange.from_smirnoff(
        openff.toolkit.ForceField("tip4p_fb.offxml"),
        openff.toolkit.Molecule.from_smiles("O").to_topology(),
    )
    ff, _ = smee.converters.convert_interchange(interchange)
    # make sure we have vsites in the force field
    assert ff.v_sites is not None
    # NOTE(review): the trailing "once" comes from interchange's v-site key
    # naming (match="once") — awkward to spell in a YAML config, but kept so
    # the expected id stays precise.
    expected_ids = ["[#1:2]-[#8X2H2+0:1]-[#1:3] EP once"]
    vsite_ids = [key.id for key in ff.v_sites.keys]
    assert vsite_ids == expected_ids

    return ff


@pytest.fixture()
def mock_vsite_configs(water_sites_ff):
    """Config that trains only the distance of the first v-site key."""
    first_vsite_key = water_sites_ff.v_sites.keys[0]
    return ParameterConfig(
        cols=["distance"],
        scales={"distance": 10.0},
        limits={"distance": (-1.0, -0.01)},
        include=[first_vsite_key],
    )


@pytest.fixture()
def mock_water_parameter_config(water_sites_ff):
    """vdW parameter config that includes only the first water vdW key."""
    vdw_potential = water_sites_ff.potentials_by_type["vdW"]
    vdw_config = ParameterConfig(
        cols=["epsilon", "sigma"],
        scales={"epsilon": 10.0, "sigma": 1.0},
        limits={"epsilon": (0.0, None), "sigma": (0.0, None)},
        include=[vdw_potential.parameter_keys[0]],
    )
    return {"vdW": vdw_config}


@pytest.fixture()
def mock_parameter_configs(mock_ff):
return {
Expand Down Expand Up @@ -286,148 +327,150 @@ def test_clamp(self, mock_ff, mock_parameter_configs, mock_attribute_configs):
assert values.shape == expected_values.shape
assert torch.allclose(values, expected_values)

def test_init_vsites(
    self, water_sites_ff, mock_vsite_configs, mock_water_parameter_config
):
    """Trainable built with v-sites exposes the expected flat parameter layout."""
    trainable = Trainable(
        water_sites_ff,
        parameters=mock_water_parameter_config,
        attributes={},
        vsites=mock_vsite_configs,
    )

    assert trainable._param_types == ["vdW"]
    # check we have a vdW parameter for the oxygen, hydrogen and vsite
    assert trainable._param_shapes == [(3, 2)]
    assert trainable._attr_types == []

    # 3 x 2 vdW values + 1 x 3 v-site values, flattened and concatenated
    assert trainable._values.shape == (9,)
    assert torch.allclose(
        trainable._values,
        torch.cat(
            [
                water_sites_ff.potentials_by_type["vdW"].parameters.flatten(),
                water_sites_ff.v_sites.parameters.flatten(),
            ]
        ),
    )

    # check frozen parameters
    # vdW params: eps, sig, eps, sig where only first smirks is unfrozen
    # vsite params: dist, inplane, outplane where first smirks is unfrozen
    expected_unfrozen_ids = torch.tensor([0, 1, 6])
    assert (trainable._unfrozen_idxs == expected_unfrozen_ids).all()

    # clamp limits / scales are per unfrozen value: eps, sigma, distance
    assert torch.allclose(
        trainable._clamp_lower,
        torch.tensor([0.0, 0.0, -1.0], dtype=torch.float64),
    )
    assert torch.allclose(
        trainable._clamp_upper,
        torch.tensor([torch.inf, torch.inf, -0.01], dtype=torch.float64),
    )
    assert torch.allclose(
        trainable._scales,
        torch.tensor([10.0, 1.0, 10.0], dtype=torch.float64),
    )

def test_to_values_vsites(
    self, water_sites_ff, mock_vsite_configs, mock_water_parameter_config
):
    """``to_values`` returns the unfrozen (vdW + v-site) values, scaled."""
    trainable = Trainable(
        water_sites_ff,
        parameters=mock_water_parameter_config,
        attributes={},
        vsites=mock_vsite_configs,
    )
    vdw_params = water_sites_ff.potentials_by_type["vdW"].parameters.flatten()
    vsite_params = water_sites_ff.v_sites.parameters.flatten()

    expected_values = torch.tensor(
        [
            vdw_params[0] * 10,  # scale eps
            vdw_params[1],  # sigma no scale
            vsite_params[0] * 10,  # scale vsite distance
        ]
    )
    values = trainable.to_values()

    assert values.shape == expected_values.shape
    assert torch.allclose(values, expected_values)

def test_to_force_field_vsites_no_op(
    self, water_sites_ff, mock_vsite_configs, mock_water_parameter_config
):
    """Round-tripping values through ``to_force_field`` leaves the FF unchanged."""
    ff_initial = copy.deepcopy(water_sites_ff)

    trainable = Trainable(
        water_sites_ff,
        parameters=mock_water_parameter_config,
        attributes={},
        vsites=mock_vsite_configs,
    )

    ff = trainable.to_force_field(trainable.to_values())
    assert (
        ff.potentials_by_type["vdW"].parameters.shape
        == ff_initial.potentials_by_type["vdW"].parameters.shape
    )
    assert torch.allclose(
        ff.potentials_by_type["vdW"].parameters,
        ff_initial.potentials_by_type["vdW"].parameters,
    )
    # vsite parameters are not float64 in the initial ff
    vsite_initial = smee.utils.tensor_like(
        ff_initial.v_sites.parameters, ff.v_sites.parameters
    )
    assert torch.allclose(
        ff.v_sites.parameters,
        vsite_initial,
    )

def test_to_force_field_clamp_vsites(
    self, water_sites_ff, mock_vsite_configs, mock_water_parameter_config
):
    """Out-of-range v-site distances are clamped when rebuilding the FF."""
    trainable = Trainable(
        water_sites_ff,
        parameters=mock_water_parameter_config,
        attributes={},
        vsites=mock_vsite_configs,
    )

    # The trainable values are, in order, the vdW parameters (eps, sigma)
    # followed by the vsite distance. When we set the last trainable
    # value to 0.0, this corresponds to the vsite distance, which is the first
    # parameter in ff.v_sites.parameters.
    values = trainable.to_values().detach()
    # set the distance to outside the clamp region
    values[-1] = 0.0
    ff = trainable.to_force_field(values)
    assert torch.allclose(
        ff.v_sites.parameters[0],
        torch.tensor([-0.0100, 3.1416, 0.0000], dtype=torch.float64),
    )

def test_init_vsites_regularization(
    self, water_sites_ff, mock_water_parameter_config
):
    """Regularization on a v-site column yields the right index and weight."""
    vsite_config = ParameterConfig(
        cols=["distance"],
        scales={"distance": 10.0},
        limits={"distance": (-1.0, -0.01)},
        regularize={"distance": 0.25},
        include=[water_sites_ff.v_sites.keys[0]],
    )

    trainable = Trainable(
        water_sites_ff,
        parameters=mock_water_parameter_config,
        attributes={},
        vsites=vsite_config,
    )

    # unfrozen order is eps, sigma, distance — only distance (idx 2) is
    # regularized, with the configured weight
    assert torch.equal(trainable.regularized_idxs, torch.tensor([2]))
    assert torch.allclose(
        trainable.regularization_weights,
        torch.tensor([0.25], dtype=torch.float64),
    )
Loading
Loading