Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
247 changes: 130 additions & 117 deletions descent/tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest
import smee
import smee.converters
import smee.utils
import torch

from descent.train import AttributeConfig, ParameterConfig, Trainable, _PotentialKey
Expand Down Expand Up @@ -46,6 +47,45 @@ def mock_ff() -> smee.TensorForceField:
return ff


@pytest.fixture()
def water_sites_ff():
interachange = openff.interchange.Interchange.from_smirnoff(
openff.toolkit.ForceField("tip4p_fb.offxml"),
openff.toolkit.Molecule.from_smiles("O").to_topology(),
)
ff, _ = smee.converters.convert_interchange(interachange)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tiny typo: "interachange"

# make sure we have vsites in the force field
assert ff.v_sites is not None
# this is awkward to specify in the yaml config file can we make it easier?
expected_ids = ["[#1:2]-[#8X2H2+0:1]-[#1:3] EP once"]
vsite_ids = [key.id for key in ff.v_sites.keys]
assert vsite_ids == expected_ids

return ff


@pytest.fixture()
def mock_vsite_configs(water_sites_ff):
return ParameterConfig(
cols=["distance"],
scales={"distance": 10.0},
limits={"distance": (-1.0, -0.01)},
include=[water_sites_ff.v_sites.keys[0]],
)


@pytest.fixture()
def mock_water_parameter_config(water_sites_ff):
return {
"vdW": ParameterConfig(
cols=["epsilon", "sigma"],
scales={"epsilon": 10.0, "sigma": 1.0},
limits={"epsilon": (0.0, None), "sigma": (0.0, None)},
include=[water_sites_ff.potentials_by_type["vdW"].parameter_keys[0]],
),
}


@pytest.fixture()
def mock_parameter_configs(mock_ff):
return {
Expand Down Expand Up @@ -286,148 +326,121 @@ def test_clamp(self, mock_ff, mock_parameter_configs, mock_attribute_configs):
assert values.shape == expected_values.shape
assert torch.allclose(values, expected_values)

def test_regularized_idxs_no_regularization(
self, mock_ff, mock_parameter_configs, mock_attribute_configs
def test_init_vsites(
self, water_sites_ff, mock_vsite_configs, mock_water_parameter_config
):
trainable = Trainable(
mock_ff,
parameters=mock_parameter_configs,
attributes=mock_attribute_configs,
water_sites_ff,
parameters=mock_water_parameter_config,
attributes={},
vsites=mock_vsite_configs,
)

assert len(trainable.regularized_idxs) == 0
assert len(trainable.regularization_weights) == 0
assert trainable._param_types == ["vdW"]
# check we have a vdW parameter for the oxygen, hydrogen and vsite
assert trainable._param_shapes == [(3, 2)]
assert trainable._attr_types == []

def test_regularized_idxs_with_parameter_regularization(self, mock_ff):
parameter_configs = {
"vdW": ParameterConfig(
cols=["epsilon", "sigma"],
regularize={"epsilon": 0.01, "sigma": 0.001},
assert trainable._values.shape == (9,)
assert torch.allclose(
trainable._values,
torch.cat(
[
water_sites_ff.potentials_by_type["vdW"].parameters.flatten(),
water_sites_ff.v_sites.parameters.flatten(),
]
),
}
attribute_configs = {}

trainable = Trainable(
mock_ff,
parameters=parameter_configs,
attributes=attribute_configs,
)

# vdW has 2 parameters (2 rows), and we're regularizing both epsilon and sigma
# So we should have 4 regularized values total: 2 epsilons + 2 sigmas
expected_idxs = torch.tensor([0, 1, 2, 3], dtype=torch.long)
assert torch.equal(trainable.regularized_idxs, expected_idxs)
# check frozen parameters
# vdW params: eps, sig, eps, sig where only first smirks is unfrozen
# vsite params: dist, inplane, outplane where first smirks is unfrozen
expected_unfrozen_ids = torch.tensor([0, 1, 6])
assert (trainable._unfrozen_idxs == expected_unfrozen_ids).all()

# Check the weights match what we configured
# Interleaved: row 0 (eps, sig), row 1 (eps, sig)
expected_weights = torch.tensor(
[0.01, 0.001, 0.01, 0.001], dtype=trainable.regularization_weights.dtype
assert torch.allclose(
trainable._clamp_lower,
torch.tensor([0.0, 0.0, -1.0], dtype=torch.float64),
)
assert torch.allclose(trainable.regularization_weights, expected_weights)

def test_regularized_idxs_with_attribute_regularization(self, mock_ff):
parameter_configs = {}
attribute_configs = {
"vdW": AttributeConfig(
cols=["scale_14", "scale_15"],
regularize={"scale_14": 0.05},
)
}

trainable = Trainable(
mock_ff,
parameters=parameter_configs,
attributes=attribute_configs,
assert torch.allclose(
trainable._clamp_upper,
torch.tensor([torch.inf, torch.inf, -0.01], dtype=torch.float64),
)

# Only scale_14 should be regularized (1 attribute)
expected_idxs = torch.tensor([0], dtype=torch.long)
assert torch.equal(trainable.regularized_idxs, expected_idxs)

expected_weights = torch.tensor(
[0.05], dtype=trainable.regularization_weights.dtype
assert torch.allclose(
trainable._scales,
torch.tensor([10.0, 1.0, 10.0], dtype=torch.float64),
)
assert torch.allclose(trainable.regularization_weights, expected_weights)

def test_regularized_idxs_with_mixed_regularization(self, mock_ff):
parameter_configs = {
"vdW": ParameterConfig(
cols=["epsilon", "sigma"],
regularize={"epsilon": 0.02},
include=[mock_ff.potentials_by_type["vdW"].parameter_keys[0]],
),
}
attribute_configs = {
"vdW": AttributeConfig(
cols=["scale_14"],
regularize={"scale_14": 0.1},
)
}

def test_to_values_vsites(
self, water_sites_ff, mock_vsite_configs, mock_water_parameter_config
):
trainable = Trainable(
mock_ff,
parameters=parameter_configs,
attributes=attribute_configs,
water_sites_ff,
parameters=mock_water_parameter_config,
attributes={},
vsites=mock_vsite_configs,
)
vdw_params = water_sites_ff.potentials_by_type["vdW"].parameters.flatten()
vsite_params = water_sites_ff.v_sites.parameters.flatten()

# Only first vdW parameter row is included, with only epsilon regularized
# Plus scale_14 attribute
expected_idxs = torch.tensor([0, 2], dtype=torch.long)
assert torch.equal(trainable.regularized_idxs, expected_idxs)

# First should be epsilon (0.02), second should be scale_14 (0.1)
expected_weights = torch.tensor(
[0.02, 0.1], dtype=trainable.regularization_weights.dtype
expected_values = torch.tensor(
[
vdw_params[0] * 10, # scale eps
vdw_params[1], # sigma no scale
vsite_params[0] * 10, # scale vsite distance
]
)
assert torch.allclose(trainable.regularization_weights, expected_weights)
values = trainable.to_values()

def test_regularized_idxs_excluded_parameters(self, mock_ff):
parameter_configs = {
"Bonds": ParameterConfig(
cols=["k", "length"],
regularize={"k": 0.01, "length": 0.02},
exclude=[mock_ff.potentials_by_type["Bonds"].parameter_keys[0]],
),
}
attribute_configs = {}
assert values.shape == expected_values.shape
assert torch.allclose(values, expected_values)

def test_to_force_field_vsites_no_op(
self, water_sites_ff, mock_vsite_configs, mock_water_parameter_config
):
ff_initial = copy.deepcopy(water_sites_ff)

trainable = Trainable(
mock_ff,
parameters=parameter_configs,
attributes=attribute_configs,
water_sites_ff,
parameters=mock_water_parameter_config,
attributes={},
vsites=mock_vsite_configs,
)

# Only second bond parameter row should be included (first is excluded)
# Both k and length are regularized
expected_idxs = torch.tensor([0, 1], dtype=torch.long)
assert torch.equal(trainable.regularized_idxs, expected_idxs)

expected_weights = torch.tensor(
[0.01, 0.02], dtype=trainable.regularization_weights.dtype
ff = trainable.to_force_field(trainable.to_values())
assert (
ff.potentials_by_type["vdW"].parameters.shape
== ff_initial.potentials_by_type["vdW"].parameters.shape
)
assert torch.allclose(
ff.potentials_by_type["vdW"].parameters,
ff_initial.potentials_by_type["vdW"].parameters,
)
# vsite parameters are not float64 in the initial ff
vsite_initial = smee.utils.tensor_like(
ff_initial.v_sites.parameters, ff.v_sites.parameters
)
assert torch.allclose(trainable.regularization_weights, expected_weights)

def test_regularization_indices_match_unfrozen_values(self, mock_ff):
parameter_configs = {
"vdW": ParameterConfig(
cols=["epsilon"],
regularize={"epsilon": 0.01},
),
}
attribute_configs = {}
assert torch.allclose(
ff.v_sites.parameters,
vsite_initial,
)

def test_to_force_field_clamp_vsites(
self, water_sites_ff, mock_vsite_configs, mock_water_parameter_config
):
trainable = Trainable(
mock_ff,
parameters=parameter_configs,
attributes=attribute_configs,
water_sites_ff,
parameters=mock_water_parameter_config,
attributes={},
vsites=mock_vsite_configs,
)

values = trainable.to_values()

# Regularization indices should be valid indices into the unfrozen values
assert trainable.regularized_idxs.max() < len(values)
assert trainable.regularized_idxs.min() >= 0

# We should be able to index the values tensor with regularization indices
regularized_values = values[trainable.regularized_idxs]
assert len(regularized_values) == len(trainable.regularized_idxs)
values = trainable.to_values().detach()
# set the distance to outside the clamp region
values[-1] = 0.0
ff = trainable.to_force_field(values)
assert torch.allclose(
ff.v_sites.parameters[0],
torch.tensor([-0.0100, 3.1416, 0.0000], dtype=torch.float64),
)
Loading
Loading