From 438e5d23c177b050233a644d643fa11141e46da0 Mon Sep 17 00:00:00 2001
From: Josh Horton
Date: Fri, 30 Aug 2024 10:18:49 +0100
Subject: [PATCH 1/8] include vsite fitting in the trainer

---
 descent/tests/test_train.py | 159 ++++++++++++++++++++++++++++++++++++
 descent/train.py            | 114 +++++++++++++++++++++++---
 2 files changed, 260 insertions(+), 13 deletions(-)

diff --git a/descent/tests/test_train.py b/descent/tests/test_train.py
index 3884a55..d2e39ee 100644
--- a/descent/tests/test_train.py
+++ b/descent/tests/test_train.py
@@ -7,6 +7,7 @@
 import pytest
 import smee
 import smee.converters
+import smee.utils
 import torch
 
 from descent.train import AttributeConfig, ParameterConfig, Trainable, _PotentialKey
@@ -46,6 +47,45 @@ def mock_ff() -> smee.TensorForceField:
     return ff
 
 
+@pytest.fixture()
+def water_sites_ff():
+    interachange = openff.interchange.Interchange.from_smirnoff(
+        openff.toolkit.ForceField("tip4p_fb.offxml"),
+        openff.toolkit.Molecule.from_smiles("O").to_topology(),
+    )
+    ff, _ = smee.converters.convert_interchange(interachange)
+    # make sure we have vsites in the force field
+    assert ff.v_sites is not None
+    # this is awkward to specify in the yaml config file can we make it easier?
+    expected_ids = ["[#1:2]-[#8X2H2+0:1]-[#1:3] EP once"]
+    vsite_ids = [key.id for key in ff.v_sites.keys]
+    assert vsite_ids == expected_ids
+
+    return ff
+
+
+@pytest.fixture()
+def mock_vsite_configs(water_sites_ff):
+    return ParameterConfig(
+        cols=["distance"],
+        scales={"distance": 10.0},
+        limits={"distance": (-1.0, -0.01)},
+        include=[water_sites_ff.v_sites.keys[0]],
+    )
+
+
+@pytest.fixture()
+def mock_water_parameter_config(water_sites_ff):
+    return {
+        "vdW": ParameterConfig(
+            cols=["epsilon", "sigma"],
+            scales={"epsilon": 10.0, "sigma": 1.0},
+            limits={"epsilon": (0.0, None), "sigma": (0.0, None)},
+            include=[water_sites_ff.potentials_by_type["vdW"].parameter_keys[0]],
+        ),
+    }
+
+
 @pytest.fixture()
 def mock_parameter_configs(mock_ff):
     return {
@@ -268,3 +308,122 @@ def test_clamp(self, mock_ff, mock_parameter_configs, mock_attribute_configs):
 
         assert values.shape == expected_values.shape
         assert torch.allclose(values, expected_values)
+
+    def test_init_vsites(
+        self, water_sites_ff, mock_vsite_configs, mock_water_parameter_config
+    ):
+        trainable = Trainable(
+            water_sites_ff,
+            parameters=mock_water_parameter_config,
+            attributes={},
+            vsites=mock_vsite_configs,
+        )
+
+        assert trainable._param_types == ["vdW"]
+        # check we have a vdW parameter for the oxygen, hydrogen and vsite
+        assert trainable._param_shapes == [(3, 2)]
+        assert trainable._attr_types == []
+
+        assert trainable._values.shape == (9,)
+        assert torch.allclose(
+            trainable._values,
+            torch.cat(
+                [
+                    water_sites_ff.potentials_by_type["vdW"].parameters.flatten(),
+                    water_sites_ff.v_sites.parameters.flatten(),
+                ]
+            ),
+        )
+
+        # check frozen parameters
+        # vdW params: (eps, sig) for each of the 3 smirks; only the first is unfrozen
+        # vsite params: dist, inplane, outplane; only the first smirks is unfrozen
+        expected_unfrozen_ids = torch.tensor([0, 1, 6])
+        assert (trainable._unfrozen_idxs == expected_unfrozen_ids).all()
+
+        assert torch.allclose(
+            trainable._clamp_lower,
+            torch.tensor([0.0, 0.0, -1.0], dtype=torch.float64),
+        )
+        assert torch.allclose(
+            trainable._clamp_upper,
+            torch.tensor([torch.inf, torch.inf, -0.01], dtype=torch.float64),
+        )
+        assert torch.allclose(
+            trainable._scales,
+            torch.tensor([10.0, 1.0, 10.0], dtype=torch.float64),
+        )
+
+    def test_to_values_vsites(
+        self, water_sites_ff, mock_vsite_configs, mock_water_parameter_config
+    ):
+        trainable = Trainable(
+            water_sites_ff,
+            parameters=mock_water_parameter_config,
+            attributes={},
+            vsites=mock_vsite_configs,
+        )
+        vdw_params = water_sites_ff.potentials_by_type["vdW"].parameters.flatten()
+        vsite_params = water_sites_ff.v_sites.parameters.flatten()
+
+        expected_values = torch.tensor(
+            [
+                vdw_params[0] * 10,  # scale eps
+                vdw_params[1],  # sigma no scale
+                vsite_params[0] * 10,  # scale vsite distance
+            ]
+        )
+        values = trainable.to_values()
+
+        assert values.shape == expected_values.shape
+        assert torch.allclose(values, expected_values)
+
+    def test_to_force_field_vsites_no_op(
+        self, water_sites_ff, mock_vsite_configs, mock_water_parameter_config
+    ):
+        ff_initial = copy.deepcopy(water_sites_ff)
+
+        trainable = Trainable(
+            water_sites_ff,
+            parameters=mock_water_parameter_config,
+            attributes={},
+            vsites=mock_vsite_configs,
+        )
+
+        ff = trainable.to_force_field(trainable.to_values())
+        assert (
+            ff.potentials_by_type["vdW"].parameters.shape
+            == ff_initial.potentials_by_type["vdW"].parameters.shape
+        )
+        assert torch.allclose(
+            ff.potentials_by_type["vdW"].parameters,
+            ff_initial.potentials_by_type["vdW"].parameters,
+        )
+        # vsite parameters are not float64 in the initial ff
+        vsite_initial = smee.utils.tensor_like(
+            ff_initial.v_sites.parameters, ff.v_sites.parameters
+        )
+
+        assert torch.allclose(
+            ff.v_sites.parameters,
+            vsite_initial,
+        )
+
+    def test_to_force_field_clamp_vsites(
+        self, water_sites_ff, mock_vsite_configs, mock_water_parameter_config
+    ):
+        trainable = Trainable(
+            water_sites_ff,
+            parameters=mock_water_parameter_config,
+            attributes={},
+            vsites=mock_vsite_configs,
+        )
+
+        values = trainable.to_values().detach()
+        # set the distance to outside the clamp region
+        values[-1] = 0.0
+        ff = trainable.to_force_field(values)
+        assert torch.allclose(
+            ff.v_sites.parameters[0],
+            torch.tensor([-0.0100, 3.1416, 0.0000], dtype=torch.float64),
+        )
diff --git a/descent/train.py b/descent/train.py
index 0c7d938..142b52d 100644
--- a/descent/train.py
+++ b/descent/train.py
@@ -251,11 +251,68 @@ def _prepare(
             smee.utils.tensor_like(clamp_upper, values),
         )
 
+    def _prepare_vsites(
+        self, force_field: smee.TensorForceField, config: ParameterConfig
+    ):
+        """
+        Prepare the vsite parameters for optimisation.
+
+        Args:
+            force_field: The tensor force field with parameters
+                which should be optimised.
+            config: The config of the parameters to train.
+
+        Returns:
+
+        """
+        vsite_parameters = force_field.v_sites.parameters.detach().clone()
+        n_rows = vsite_parameters.shape[0]
+        vsite_parameters_flat = vsite_parameters.flatten()
+        # define the cols as they are not on the tensor model
+        vsite_cols = ["distance", "inPlaneAngle", "outOfPlaneAngle"]
+
+        all_keys = [_PotentialKey(**key.dict()) for key in force_field.v_sites.keys]
+        excluded_keys = config.exclude or []
+        unfrozen_keys = config.include or all_keys
+
+        key_to_row = {key: row_idx for row_idx, key in enumerate(all_keys)}
+        assert len(key_to_row) == len(all_keys), "duplicate keys found"
+
+        unfrozen_rows = {
+            key_to_row[key] for key in unfrozen_keys if key not in excluded_keys
+        }
+
+        unfrozen_idxs = [
+            col_idx + row_idx * vsite_parameters.shape[1]
+            for row_idx in range(n_rows)
+            if row_idx in unfrozen_rows
+            # the vsite model has no parameter cols so define here
+            for col_idx, col in enumerate(vsite_cols)
+            if col in config.cols
+        ]
+        vsite_scales = [config.scales.get(col, 1.0) for col in vsite_cols] * n_rows
+        clamp_lower = [
+            config.limits.get(col, (None, None))[0] for col in vsite_cols
+        ] * n_rows
+        clamp_lower = [-torch.inf if x is None else x for x in clamp_lower]
+        clamp_upper = [
+            config.limits.get(col, (None, None))[1] for col in vsite_cols
+        ] * n_rows
+        clamp_upper = [torch.inf if x is None else x for x in clamp_upper]
+        return (
+            vsite_parameters_flat,
+            torch.tensor(unfrozen_idxs),
+            smee.utils.tensor_like(vsite_scales, vsite_parameters),
+            smee.utils.tensor_like(clamp_lower, vsite_parameters),
+            smee.utils.tensor_like(clamp_upper, vsite_parameters),
+        )
+
     def __init__(
         self,
         force_field: smee.TensorForceField,
         parameters: dict[str, ParameterConfig],
         attributes: dict[str, AttributeConfig],
+        vsites: ParameterConfig | None = None,
     ):
         """
 
@@ -263,8 +320,10 @@ def __init__(
             force_field: The force field to wrap.
             parameters: Configure which parameters to train.
             attributes: Configure which attributes to train.
+            vsites: Configure which vsite parameters to train.
""" self._force_field = force_field + self._fit_vsites = False ( self._param_types, @@ -285,20 +344,38 @@ def __init__( attr_clamp_upper, ) = self._prepare(force_field, attributes, "attributes") - self._values = torch.cat([param_values, attr_values]) + values = [param_values, attr_values] + unfrozen_idxs = [param_unfrozen_idxs, attr_unfrozen_idxs + len(param_scales)] + scales = [param_scales, attr_scales] + clamp_lower = [param_clamp_lower, attr_clamp_lower] + clamp_upper = [param_clamp_upper, attr_clamp_upper] + + if vsites is not None: + ( + vsite_values, + vsite_unfrozen_idxs, + vsite_scales, + vsite_clamp_lower, + vsite_clamp_upper, + ) = self._prepare_vsites(force_field, vsites) + self._fit_vsites = True + + values.append(vsite_values) + unfrozen_idxs.append( + (vsite_unfrozen_idxs + len(param_scales) + len(attr_scales)) + ) + scales.append(vsite_scales) + clamp_lower.append(vsite_clamp_lower) + clamp_upper.append(vsite_clamp_upper) - self._unfrozen_idxs = torch.cat( - [param_unfrozen_idxs, attr_unfrozen_idxs + len(param_scales)] - ).long() + self._values = torch.cat(values) - self._scales = torch.cat([param_scales, attr_scales])[self._unfrozen_idxs] + self._unfrozen_idxs = torch.cat(unfrozen_idxs).long() - self._clamp_lower = torch.cat([param_clamp_lower, attr_clamp_lower])[ - self._unfrozen_idxs - ] - self._clamp_upper = torch.cat([param_clamp_upper, attr_clamp_upper])[ - self._unfrozen_idxs - ] + self._scales = torch.cat(scales)[self._unfrozen_idxs] + + self._clamp_lower = torch.cat(clamp_lower)[self._unfrozen_idxs] + self._clamp_upper = torch.cat(clamp_upper)[self._unfrozen_idxs] @torch.no_grad() def to_values(self) -> torch.Tensor: @@ -320,18 +397,29 @@ def to_force_field(self, values_flat: torch.Tensor) -> smee.TensorForceField: values[self._unfrozen_idxs] = (values_flat / self._scales).clamp( min=self._clamp_lower, max=self._clamp_upper ) - values = _unflatten_tensors(values, self._param_shapes + self._attr_shapes) + shapes = self._param_shapes + self._attr_shapes + + if self._fit_vsites: + shapes.append(self._force_field.v_sites.parameters.shape) + + values = _unflatten_tensors(values, shapes) params = values[: len(self._param_shapes)] for potential_type, param in zip(self._param_types, params, strict=True): potentials[potential_type].parameters = param - attrs = values[len(self._param_shapes) :] + attrs = values[ + len(self._param_shapes) : len(self._param_shapes) + len(self._attr_shapes) + ] for potential_type, attr in zip(self._attr_types, attrs, strict=True): potentials[potential_type].attributes = attr + if self._fit_vsites: + vsite_params = values[len(self._param_shapes) + len(self._attr_shapes) :] + self._force_field.v_sites.parameters = vsite_params[0] + return self._force_field @torch.no_grad() From 58deeec7ddf6ca90780ff4090f0847427af98524 Mon Sep 17 00:00:00 2001 From: Finlay Clark Date: Mon, 16 Mar 2026 16:13:14 +0000 Subject: [PATCH 2/8] Allow regularization of vsite parameters --- descent/tests/test_train.py | 24 ++++++++++++++++ descent/train.py | 57 ++++++++++++++++++++++++++++--------- 2 files changed, 68 insertions(+), 13 deletions(-) diff --git a/descent/tests/test_train.py b/descent/tests/test_train.py index 1aae364..8ef734a 100644 --- a/descent/tests/test_train.py +++ b/descent/tests/test_train.py @@ -444,3 +444,27 @@ def test_to_force_field_clamp_vsites( ff.v_sites.parameters[0], torch.tensor([-0.0100, 3.1416, 0.0000], dtype=torch.float64), ) + + def test_init_vsites_regularization( + self, water_sites_ff, mock_water_parameter_config + ): + vsite_config = 
ParameterConfig( + cols=["distance"], + scales={"distance": 10.0}, + limits={"distance": (-1.0, -0.01)}, + regularize={"distance": 0.25}, + include=[water_sites_ff.v_sites.keys[0]], + ) + + trainable = Trainable( + water_sites_ff, + parameters=mock_water_parameter_config, + attributes={}, + vsites=vsite_config, + ) + + assert torch.equal(trainable.regularized_idxs, torch.tensor([2])) + assert torch.allclose( + trainable.regularization_weights, + torch.tensor([0.25], dtype=torch.float64), + ) diff --git a/descent/train.py b/descent/train.py index 4a2ee59..efd7475 100644 --- a/descent/train.py +++ b/descent/train.py @@ -357,14 +357,26 @@ def _prepare_vsites( key_to_row[key] for key in unfrozen_keys if key not in excluded_keys } - unfrozen_idxs = [ - col_idx + row_idx * vsite_parameters.shape[1] - for row_idx in range(n_rows) - if row_idx in unfrozen_rows + unfrozen_idxs = [] + regularized_idxs = [] + regularization_weights = [] + + for row_idx in range(n_rows): + if row_idx not in unfrozen_rows: + continue + # the vsite model has no parameter cols so define here - for col_idx, col in enumerate(vsite_cols) - if col in config.cols - ] + for col_idx, col in enumerate(vsite_cols): + if col not in config.cols: + continue + + flat_idx = col_idx + row_idx * vsite_parameters.shape[1] + unfrozen_idxs.append(flat_idx) + + if col in config.regularize: + regularized_idxs.append(flat_idx) + regularization_weights.append(config.regularize[col]) + vsite_scales = [config.scales.get(col, 1.0) for col in vsite_cols] * n_rows clamp_lower = [ config.limits.get(col, (None, None))[0] for col in vsite_cols @@ -380,6 +392,12 @@ def _prepare_vsites( smee.utils.tensor_like(vsite_scales, vsite_parameters), smee.utils.tensor_like(clamp_lower, vsite_parameters), smee.utils.tensor_like(clamp_upper, vsite_parameters), + torch.tensor(regularized_idxs), + ( + smee.utils.tensor_like(regularization_weights, vsite_parameters) + if regularization_weights + else smee.utils.tensor_like([], vsite_parameters) + ), ) def __init__( @@ -436,6 +454,8 @@ def __init__( vsite_scales, vsite_clamp_lower, vsite_clamp_upper, + vsite_regularized_idxs, + vsite_regularization_weights, ) = self._prepare_vsites(force_field, vsites) self._fit_vsites = True @@ -457,12 +477,23 @@ def __init__( self._clamp_upper = torch.cat(clamp_upper)[self._unfrozen_idxs] # Store regularization information - all_regularized_idxs = torch.cat( - [param_regularized_idxs, attr_regularized_idxs + len(param_scales)] - ).long() - all_regularization_weights = torch.cat( - [param_regularization_weights, attr_regularization_weights] - ) + all_regularized_idxs = [ + param_regularized_idxs, + attr_regularized_idxs + len(param_scales), + ] + all_regularization_weights = [ + param_regularization_weights, + attr_regularization_weights, + ] + + if self._fit_vsites: + all_regularized_idxs.append( + vsite_regularized_idxs + len(param_scales) + len(attr_scales) + ) + all_regularization_weights.append(vsite_regularization_weights) + + all_regularized_idxs = torch.cat(all_regularized_idxs).long() + all_regularization_weights = torch.cat(all_regularization_weights) # Map global indices to unfrozen indices idx_mapping = {idx.item(): i for i, idx in enumerate(self._unfrozen_idxs)} From 8246eca670ec6107960c8b662b24720f45194876 Mon Sep 17 00:00:00 2001 From: Finlay Clark Date: Mon, 16 Mar 2026 16:14:28 +0000 Subject: [PATCH 3/8] Fix type: interachange --- descent/tests/test_train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/descent/tests/test_train.py 
b/descent/tests/test_train.py index 8ef734a..a44c9fb 100644 --- a/descent/tests/test_train.py +++ b/descent/tests/test_train.py @@ -49,11 +49,11 @@ def mock_ff() -> smee.TensorForceField: @pytest.fixture() def water_sites_ff(): - interachange = openff.interchange.Interchange.from_smirnoff( + interchange = openff.interchange.Interchange.from_smirnoff( openff.toolkit.ForceField("tip4p_fb.offxml"), openff.toolkit.Molecule.from_smiles("O").to_topology(), ) - ff, _ = smee.converters.convert_interchange(interachange) + ff, _ = smee.converters.convert_interchange(interchange) # make sure we have vsites in the force field assert ff.v_sites is not None # this is awkward to specify in the yaml config file can we make it easier? From 9a7744c331410eee18587435651d9aa9b2b3cd04 Mon Sep 17 00:00:00 2001 From: Finlay Clark Date: Mon, 16 Mar 2026 16:16:19 +0000 Subject: [PATCH 4/8] Avoid hard-coding vsite parameter col names --- descent/train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/descent/train.py b/descent/train.py index efd7475..f55c5c6 100644 --- a/descent/train.py +++ b/descent/train.py @@ -343,8 +343,9 @@ def _prepare_vsites( vsite_parameters = force_field.v_sites.parameters.detach().clone() n_rows = vsite_parameters.shape[0] vsite_parameters_flat = vsite_parameters.flatten() - # define the cols as they are not on the tensor model - vsite_cols = ["distance", "inPlaneAngle", "outOfPlaneAngle"] + # Getting the column names from this dict is a bit awkward but + # avoids hard-coding them. + vsite_cols = list(force_field.v_sites.default_units().keys()) all_keys = [_PotentialKey(**key.dict()) for key in force_field.v_sites.keys] excluded_keys = config.exclude or [] From e05bc9fd12f0acbf2f18535042bfde602469c437 Mon Sep 17 00:00:00 2001 From: Finlay Clark Date: Mon, 16 Mar 2026 16:46:26 +0000 Subject: [PATCH 5/8] Reduce duplication between preparation with/ without vsites --- descent/train.py | 201 ++++++++++++++++++++++++++--------------------- 1 file changed, 111 insertions(+), 90 deletions(-) diff --git a/descent/train.py b/descent/train.py index f55c5c6..cf96e7f 100644 --- a/descent/train.py +++ b/descent/train.py @@ -196,6 +196,80 @@ class Trainable: parameters so they are not updated during training. 
""" + @staticmethod + def _prepare_rows( + config: AttributeConfig, + n_rows: int, + all_keys: list[_PotentialKey] | None = None, + ) -> set[int]: + """Determine which rows should be trainable.""" + if not isinstance(config, ParameterConfig): + return set(range(n_rows)) + + assert all_keys is not None + + excluded_keys = config.exclude or [] + unfrozen_keys = config.include or all_keys + + key_to_row = {key: row_idx for row_idx, key in enumerate(all_keys)} + assert len(key_to_row) == len(all_keys), "duplicate keys found" + + return {key_to_row[key] for key in unfrozen_keys if key not in excluded_keys} + + @staticmethod + def _prepare_values( + values: torch.Tensor, + cols: list[str] | tuple[str, ...], + config: AttributeConfig, + unfrozen_rows: set[int], + n_rows: int, + idx_offset: int = 0, + ) -> tuple[ + list[int], + list[float], + list[float], + list[float], + list[int], + list[float], + ]: + """Prepare unfrozen indices and transform metadata for a value block.""" + row_width = values.shape[-1] + + unfrozen_idxs = [] + scales = [] + clamp_lower = [] + clamp_upper = [] + regularized_idxs = [] + regularization_weights = [] + + for row_idx in range(n_rows): + for col_idx, col in enumerate(cols): + flat_idx = idx_offset + col_idx + row_idx * row_width + + scales.append(config.scales.get(col, 1.0)) + + lower, upper = config.limits.get(col, (None, None)) + clamp_lower.append(-torch.inf if lower is None else lower) + clamp_upper.append(torch.inf if upper is None else upper) + + if row_idx not in unfrozen_rows or col not in config.cols: + continue + + unfrozen_idxs.append(flat_idx) + + if col in config.regularize: + regularized_idxs.append(flat_idx) + regularization_weights.append(config.regularize[col]) + + return ( + unfrozen_idxs, + scales, + clamp_lower, + clamp_upper, + regularized_idxs, + regularization_weights, + ) + @classmethod def _prepare( cls, @@ -241,68 +315,38 @@ def _prepare( n_rows = 1 if attr == "attributes" else len(potential_values) - unfrozen_rows = set(range(n_rows)) - + all_keys = None if isinstance(potential_config, ParameterConfig): all_keys = [ _PotentialKey(**v.dict()) for v in getattr(potential, f"{attr[:-1]}_keys") ] - excluded_keys = potential_config.exclude or [] - unfrozen_keys = potential_config.include or all_keys - - key_to_row = {key: row_idx for row_idx, key in enumerate(all_keys)} - assert len(key_to_row) == len(all_keys), "duplicate keys found" - - unfrozen_rows = { - key_to_row[key] for key in unfrozen_keys if key not in excluded_keys - } - - # Track unfrozen and regularized indices - for row_idx in range(n_rows): - if row_idx not in unfrozen_rows: - continue - for col_idx, col in enumerate(potential_cols): - if col not in potential_config.cols: - continue - - flat_idx = ( - unfrozen_col_offset - + col_idx - + row_idx * potential_values.shape[-1] - ) - unfrozen_idxs.append(flat_idx) - - if col in potential_config.regularize: - regularized_idxs.append(flat_idx) - regularization_weights.append(potential_config.regularize[col]) - - unfrozen_col_offset += len(potential_values_flat) - - potential_scales = [ - potential_config.scales.get(col, 1.0) for col in potential_cols - ] * n_rows + unfrozen_rows = cls._prepare_rows(potential_config, n_rows, all_keys) + ( + potential_unfrozen_idxs, + potential_scales, + potential_clamp_lower, + potential_clamp_upper, + potential_regularized_idxs, + potential_regularization_weights, + ) = cls._prepare_values( + potential_values, + potential_cols, + potential_config, + unfrozen_rows, + n_rows, + unfrozen_col_offset, + ) + 
unfrozen_idxs.extend(potential_unfrozen_idxs) scales.extend(potential_scales) - - potential_clamp_lower = [ - potential_config.limits.get(col, (None, None))[0] - for col in potential_cols - ] * n_rows - potential_clamp_lower = [ - -torch.inf if x is None else x for x in potential_clamp_lower - ] clamp_lower.extend(potential_clamp_lower) - - potential_clamp_upper = [ - potential_config.limits.get(col, (None, None))[1] - for col in potential_cols - ] * n_rows - potential_clamp_upper = [ - torch.inf if x is None else x for x in potential_clamp_upper - ] clamp_upper.extend(potential_clamp_upper) + regularized_idxs.extend(potential_regularized_idxs) + regularization_weights.extend(potential_regularization_weights) + + unfrozen_col_offset += len(potential_values_flat) values = ( smee.utils.tensor_like([], force_field.potentials[0].parameters) @@ -348,45 +392,22 @@ def _prepare_vsites( vsite_cols = list(force_field.v_sites.default_units().keys()) all_keys = [_PotentialKey(**key.dict()) for key in force_field.v_sites.keys] - excluded_keys = config.exclude or [] - unfrozen_keys = config.include or all_keys - - key_to_row = {key: row_idx for row_idx, key in enumerate(all_keys)} - assert len(key_to_row) == len(all_keys), "duplicate keys found" - - unfrozen_rows = { - key_to_row[key] for key in unfrozen_keys if key not in excluded_keys - } - - unfrozen_idxs = [] - regularized_idxs = [] - regularization_weights = [] - - for row_idx in range(n_rows): - if row_idx not in unfrozen_rows: - continue - - # the vsite model has no parameter cols so define here - for col_idx, col in enumerate(vsite_cols): - if col not in config.cols: - continue - - flat_idx = col_idx + row_idx * vsite_parameters.shape[1] - unfrozen_idxs.append(flat_idx) - - if col in config.regularize: - regularized_idxs.append(flat_idx) - regularization_weights.append(config.regularize[col]) + unfrozen_rows = self._prepare_rows(config, n_rows, all_keys) + ( + unfrozen_idxs, + vsite_scales, + clamp_lower, + clamp_upper, + regularized_idxs, + regularization_weights, + ) = self._prepare_values( + vsite_parameters, + vsite_cols, + config, + unfrozen_rows, + n_rows, + ) - vsite_scales = [config.scales.get(col, 1.0) for col in vsite_cols] * n_rows - clamp_lower = [ - config.limits.get(col, (None, None))[0] for col in vsite_cols - ] * n_rows - clamp_lower = [-torch.inf if x is None else x for x in clamp_lower] - clamp_upper = [ - config.limits.get(col, (None, None))[1] for col in vsite_cols - ] * n_rows - clamp_upper = [torch.inf if x is None else x for x in clamp_upper] return ( vsite_parameters_flat, torch.tensor(unfrozen_idxs), From a869d8348ea5fe84961049f3154fefa8094bff44 Mon Sep 17 00:00:00 2001 From: Finlay Clark Date: Wed, 25 Mar 2026 10:26:19 +0000 Subject: [PATCH 6/8] Add comment explaining vsites clamp test --- descent/tests/test_train.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/descent/tests/test_train.py b/descent/tests/test_train.py index a44c9fb..2a8dee0 100644 --- a/descent/tests/test_train.py +++ b/descent/tests/test_train.py @@ -10,6 +10,7 @@ import smee.utils import torch + from descent.train import AttributeConfig, ParameterConfig, Trainable, _PotentialKey @@ -436,6 +437,11 @@ def test_to_force_field_clamp_vsites( vsites=mock_vsite_configs, ) + # The trainable values are, in order, the vdW parameters (eps, sigma) + # followed by the vsite distance. # When we set the last trainable + # value to 0.0, this corresponds to the vsite distance, which is the first + # parameter in ff.v_sites.parameters. 
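+        # (to_force_field divides the raw vector by the scales before
+        # clamping: 0.0 / 10.0 = 0.0 lies above the configured upper limit
+        # of -0.01, so the distance is clamped down to -0.01 below.)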
+
         values = trainable.to_values().detach()
         # set the distance to outside the clamp region
         values[-1] = 0.0
         ff = trainable.to_force_field(values)

From 396e92c37aa78dadaf90a36c57b6a6f8f25e2fed Mon Sep 17 00:00:00 2001
From: Finlay Clark
Date: Wed, 25 Mar 2026 10:26:35 +0000
Subject: [PATCH 7/8] Update docstring

---
 descent/train.py | 12 +-----------
 1 file changed, 1 insertion(+), 11 deletions(-)

diff --git a/descent/train.py b/descent/train.py
index cf96e7f..8fb9e0e 100644
--- a/descent/train.py
+++ b/descent/train.py
@@ -373,17 +373,7 @@ def _prepare(
     def _prepare_vsites(
         self, force_field: smee.TensorForceField, config: ParameterConfig
     ):
-        """
-        Prepare the vsite parameters for optimisation.
-
-        Args:
-            force_field: The tensor force field with parameters
-                which should be optimised.
-            config: The config of the parameters to train.
-
-        Returns:
-
-        """
+        """Prepare the vsite parameters for optimisation."""
         vsite_parameters = force_field.v_sites.parameters.detach().clone()
         n_rows = vsite_parameters.shape[0]
         vsite_parameters_flat = vsite_parameters.flatten()

From 0bfb99ff876502b012795dc9ff73d57f76cf509f Mon Sep 17 00:00:00 2001
From: Finlay Clark
Date: Wed, 25 Mar 2026 10:58:39 +0000
Subject: [PATCH 8/8] Don't keep track of frozen value clamp and scales

---
 descent/train.py | 294 +++++++++++++++++++++++------------------------
 1 file changed, 145 insertions(+), 149 deletions(-)

diff --git a/descent/train.py b/descent/train.py
index 8fb9e0e..9a5b83e 100644
--- a/descent/train.py
+++ b/descent/train.py
@@ -26,6 +26,29 @@ def _unflatten_tensors(
     return tensors
 
 
+def _tensor_like_or_empty(values: list[float], like: torch.Tensor) -> torch.Tensor:
+    """Create a tensor like `like`, returning an empty one when no values exist."""
+    return (
+        smee.utils.tensor_like(values, like)
+        if values else smee.utils.tensor_like([], like)
+    )
+
+
+def _validate_trainable_keys(
+    cols: list[str],
+    scales: dict[str, float],
+    limits: dict[str, tuple[float | None, float | None]],
+    regularize: dict[str, float],
+) -> None:
+    """Ensure transform dictionaries only reference trainable columns."""
+    if any(key not in cols for key in scales):
+        raise ValueError("cannot scale non-trainable parameters")
+    if any(key not in cols for key in limits):
+        raise ValueError("cannot clamp non-trainable parameters")
+    if any(key not in cols for key in regularize):
+        raise ValueError("cannot regularize non-trainable parameters")
+
+
 if pydantic.__version__.startswith("1."):
     _PotentialKey = openff.interchange.models.PotentialKey
     PotentialKeyList = list[_PotentialKey]
@@ -62,7 +85,7 @@ def _convert_keys(value: typing.Any) -> typing.Any:
         if not isinstance(value, list):
             return value
 
-        value = [
+        return [
             (
                 _PotentialKey(**v.dict())
                 if isinstance(v, openff.interchange.models.PotentialKey)
                 else v
             )
             for v in value
         ]
-        return value
 
     PotentialKeyList = typing.Annotated[
         list[_PotentialKey], pydantic.BeforeValidator(_convert_keys)
@@ -112,12 +134,7 @@ def _validate_keys(cls, values):
         limits = values.get("limits")
         regularize = values.get("regularize")
 
-        if any(key not in cols for key in scales):
-            raise ValueError("cannot scale non-trainable parameters")
-        if any(key not in cols for key in limits):
-            raise ValueError("cannot clamp non-trainable parameters")
-        if any(key not in cols for key in regularize):
-            raise ValueError("cannot regularize non-trainable parameters")
+        _validate_trainable_keys(cols, scales, limits, regularize)
 
         return values
 
@@ -127,14 +144,9 @@ def _validate_keys(self):
         """Ensure that the keys in `scales` and `limits` match `cols`."""
-        if any(key not in self.cols for key in self.scales):
-            raise ValueError("cannot scale non-trainable parameters")
-
-        if any(key not in self.cols for key in self.limits):
-            raise ValueError("cannot clamp non-trainable parameters")
-
-        if any(key not in self.cols for key in self.regularize):
-            raise ValueError("cannot regularize non-trainable parameters")
+        _validate_trainable_keys(
+            self.cols, self.scales, self.limits, self.regularize
+        )
 
         return self
 
@@ -187,6 +199,25 @@ def _validate_include_exclude(self):
         return self
 
 
+class _PreparedBlock(typing.NamedTuple):
+    """Per-block output of ``_prepare`` and ``_prepare_vsites``.
+
+    All index tensors are zero-based relative to this block's own flat value
+    tensor. ``__init__`` is responsible for applying global offsets when
+    combining blocks.
+    """
+
+    values: torch.Tensor
+    shapes: list[torch.Size]
+    unfrozen_idxs: torch.Tensor
+    scales: torch.Tensor
+    clamp_lower: torch.Tensor
+    clamp_upper: torch.Tensor
+    # Always a subset of unfrozen_idxs.
+    regularized_idxs: torch.Tensor
+    regularization_weights: torch.Tensor
+
+
 class Trainable:
     """A convenient wrapper around a tensor force field that gives greater
     control over how parameters should be trained.
@@ -222,8 +253,6 @@ def _prepare_values(
         values: torch.Tensor,
         cols: list[str] | tuple[str, ...],
         config: AttributeConfig,
         unfrozen_rows: set[int],
-        n_rows: int,
-        idx_offset: int = 0,
     ) -> tuple[
         list[int],
         list[float],
@@ -232,8 +261,17 @@ def _prepare_values(
         list[int],
         list[float],
     ]:
-        """Prepare unfrozen indices and transform metadata for a value block."""
+        """Prepare unfrozen indices and transform metadata for a value block.
+
+        Returned indices are zero-based relative to this block's flat value
+        tensor. The caller is responsible for applying any global offset before
+        combining indices across blocks.
+
+        Only unfrozen entries are accumulated, so all returned lists are
+        parallel and require no further post-hoc indexing.
+        """
         row_width = values.shape[-1]
+        col_to_idx = {col: idx for idx, col in enumerate(cols)}
 
         unfrozen_idxs = []
         scales = []
@@ -242,21 +280,17 @@ def _prepare_values(
         regularized_idxs = []
         regularization_weights = []
 
-        for row_idx in range(n_rows):
-            for col_idx, col in enumerate(cols):
-                flat_idx = idx_offset + col_idx + row_idx * row_width
+        for row_idx in unfrozen_rows:
+            for col in config.cols:
+                flat_idx = col_to_idx[col] + row_idx * row_width
+                unfrozen_idxs.append(flat_idx)
 
                 scales.append(config.scales.get(col, 1.0))
 
                 lower, upper = config.limits.get(col, (None, None))
                 clamp_lower.append(-torch.inf if lower is None else lower)
                 clamp_upper.append(torch.inf if upper is None else upper)
 
-                if row_idx not in unfrozen_rows or col not in config.cols:
-                    continue
-
-                unfrozen_idxs.append(flat_idx)
-
                 if col in config.regularize:
                     regularized_idxs.append(flat_idx)
                     regularization_weights.append(config.regularize[col])
@@ -276,9 +310,13 @@ def _prepare(
         force_field: smee.TensorForceField,
         config: dict[str, AttributeConfig],
         attr: typing.Literal["parameters", "attributes"],
-    ):
+    ) -> _PreparedBlock:
         """Prepare the trainable parameters or attributes for the given force field and
-        configuration."""
+        configuration.
+
+        Returned indices are zero-based relative to the block's own flat value
+        tensor.
+        """
         potential_types = sorted(config)
         potentials = [
             force_field.potentials_by_type[potential_type]
@@ -289,10 +327,9 @@ def _prepare(
 
         shapes = []
         unfrozen_idxs = []
-        unfrozen_col_offset = 0
+        idx_offset = 0
 
         scales = []
-
         clamp_lower = []
         clamp_upper = []
@@ -335,45 +372,44 @@ def _prepare(
                 potential_cols,
                 potential_config,
                 unfrozen_rows,
-                n_rows,
-                unfrozen_col_offset,
             )
 
-            unfrozen_idxs.extend(potential_unfrozen_idxs)
+            unfrozen_idxs.extend(i + idx_offset for i in potential_unfrozen_idxs)
             scales.extend(potential_scales)
             clamp_lower.extend(potential_clamp_lower)
             clamp_upper.extend(potential_clamp_upper)
-            regularized_idxs.extend(potential_regularized_idxs)
+            regularized_idxs.extend(i + idx_offset for i in potential_regularized_idxs)
            regularization_weights.extend(potential_regularization_weights)
 
-            unfrozen_col_offset += len(potential_values_flat)
+            idx_offset += len(potential_values_flat)
 
-        values = (
+        flat_values = (
             smee.utils.tensor_like([], force_field.potentials[0].parameters)
             if len(values) == 0
             else torch.cat(values)
         )
 
-        return (
-            potential_types,
-            values,
-            shapes,
-            torch.tensor(unfrozen_idxs),
-            smee.utils.tensor_like(scales, values),
-            smee.utils.tensor_like(clamp_lower, values),
-            smee.utils.tensor_like(clamp_upper, values),
-            torch.tensor(regularized_idxs),
-            (
-                smee.utils.tensor_like(regularization_weights, values)
-                if regularization_weights
-                else smee.utils.tensor_like([], values)
+        return _PreparedBlock(
+            values=flat_values,
+            shapes=shapes,
+            unfrozen_idxs=torch.tensor(unfrozen_idxs),
+            scales=smee.utils.tensor_like(scales, flat_values),
+            clamp_lower=smee.utils.tensor_like(clamp_lower, flat_values),
+            clamp_upper=smee.utils.tensor_like(clamp_upper, flat_values),
+            regularized_idxs=torch.tensor(regularized_idxs),
+            regularization_weights=_tensor_like_or_empty(
+                regularization_weights, flat_values
             ),
         )
 
     def _prepare_vsites(
         self, force_field: smee.TensorForceField, config: ParameterConfig
-    ):
-        """Prepare the vsite parameters for optimisation."""
+    ) -> _PreparedBlock:
+        """Prepare the vsite parameters for optimisation.
+
+        Returned indices are zero-based relative to the block's own flat value
+        tensor.
+        """
         vsite_parameters = force_field.v_sites.parameters.detach().clone()
         n_rows = vsite_parameters.shape[0]
         vsite_parameters_flat = vsite_parameters.flatten()
@@ -395,20 +431,18 @@ def _prepare_vsites(
             vsite_cols,
             config,
             unfrozen_rows,
-            n_rows,
         )
 
-        return (
-            vsite_parameters_flat,
-            torch.tensor(unfrozen_idxs),
-            smee.utils.tensor_like(vsite_scales, vsite_parameters),
-            smee.utils.tensor_like(clamp_lower, vsite_parameters),
-            smee.utils.tensor_like(clamp_upper, vsite_parameters),
-            torch.tensor(regularized_idxs),
-            (
-                smee.utils.tensor_like(regularization_weights, vsite_parameters)
-                if regularization_weights
-                else smee.utils.tensor_like([], vsite_parameters)
+        return _PreparedBlock(
+            values=vsite_parameters_flat,
+            shapes=[vsite_parameters.shape],
+            unfrozen_idxs=torch.tensor(unfrozen_idxs),
+            scales=smee.utils.tensor_like(vsite_scales, vsite_parameters),
+            clamp_lower=smee.utils.tensor_like(clamp_lower, vsite_parameters),
+            clamp_upper=smee.utils.tensor_like(clamp_upper, vsite_parameters),
+            regularized_idxs=torch.tensor(regularized_idxs),
+            regularization_weights=_tensor_like_or_empty(
+                regularization_weights, vsite_parameters
             ),
         )
@@ -430,100 +464,62 @@ def __init__(
         self._force_field = force_field
         self._fit_vsites = False
 
-        (
-            self._param_types,
-            param_values,
-            self._param_shapes,
-            param_unfrozen_idxs,
-            param_scales,
-            param_clamp_lower,
-            param_clamp_upper,
-            param_regularized_idxs,
-            param_regularization_weights,
-        ) = self._prepare(force_field, parameters, "parameters")
-        (
-            self._attr_types,
-            attr_values,
-            self._attr_shapes,
-            attr_unfrozen_idxs,
-            attr_scales,
-            attr_clamp_lower,
-            attr_clamp_upper,
-            attr_regularized_idxs,
-            attr_regularization_weights,
-        ) = self._prepare(force_field, attributes, "attributes")
-
-        values = [param_values, attr_values]
-        unfrozen_idxs = [param_unfrozen_idxs, attr_unfrozen_idxs + len(param_scales)]
-        scales = [param_scales, attr_scales]
-        clamp_lower = [param_clamp_lower, attr_clamp_lower]
-        clamp_upper = [param_clamp_upper, attr_clamp_upper]
+        param_block = self._prepare(force_field, parameters, "parameters")
+        attr_block = self._prepare(force_field, attributes, "attributes")
 
-        if vsites is not None:
-            (
-                vsite_values,
-                vsite_unfrozen_idxs,
-                vsite_scales,
-                vsite_clamp_lower,
-                vsite_clamp_upper,
-                vsite_regularized_idxs,
-                vsite_regularization_weights,
-            ) = self._prepare_vsites(force_field, vsites)
-            self._fit_vsites = True
+        self._param_types = sorted(parameters)
+        self._param_shapes = param_block.shapes
+        self._attr_types = sorted(attributes)
+        self._attr_shapes = attr_block.shapes
 
-            values.append(vsite_values)
-            unfrozen_idxs.append(
-                (vsite_unfrozen_idxs + len(param_scales) + len(attr_scales))
-            )
-            scales.append(vsite_scales)
-            clamp_lower.append(vsite_clamp_lower)
-            clamp_upper.append(vsite_clamp_upper)
+        blocks: list[_PreparedBlock] = [param_block, attr_block]
 
-        self._values = torch.cat(values)
-
-        self._unfrozen_idxs = torch.cat(unfrozen_idxs).long()
-
-        self._scales = torch.cat(scales)[self._unfrozen_idxs]
-
-        self._clamp_lower = torch.cat(clamp_lower)[self._unfrozen_idxs]
-        self._clamp_upper = torch.cat(clamp_upper)[self._unfrozen_idxs]
+        if vsites is not None:
+            blocks.append(self._prepare_vsites(force_field, vsites))
+            self._fit_vsites = True
 
-        # Store regularization information
-        all_regularized_idxs = [
-            param_regularized_idxs,
-            attr_regularized_idxs + len(param_scales),
-        ]
-        all_regularization_weights = [
-            param_regularization_weights,
-            attr_regularization_weights,
-        ]
+        # Each block's indices are zero-based within that block. Apply the
+        # running value-tensor offset at the point of combination so that the
+        # offset arithmetic is in one place and adding further blocks requires
+        # no changes beyond appending to `blocks`.
+        all_values = []
+        all_unfrozen_idxs = []
+        all_scales = []
+        all_clamp_lower = []
+        all_clamp_upper = []
+        all_regularized_idxs = []
+        all_regularization_weights = []
+
+        offset = 0
+        for block in blocks:
+            all_values.append(block.values)
+            all_unfrozen_idxs.append(block.unfrozen_idxs + offset)
+            all_scales.append(block.scales)
+            all_clamp_lower.append(block.clamp_lower)
+            all_clamp_upper.append(block.clamp_upper)
+            all_regularized_idxs.append(block.regularized_idxs + offset)
+            all_regularization_weights.append(block.regularization_weights)
+            offset += len(block.values)
+
+        self._values = torch.cat(all_values)
+        self._unfrozen_idxs = torch.cat(all_unfrozen_idxs).long()
+        self._scales = torch.cat(all_scales)
+        self._clamp_lower = torch.cat(all_clamp_lower)
+        self._clamp_upper = torch.cat(all_clamp_upper)
 
-        if self._fit_vsites:
-            all_regularized_idxs.append(
-                vsite_regularized_idxs + len(param_scales) + len(attr_scales)
-            )
-            all_regularization_weights.append(vsite_regularization_weights)
-
-        all_regularized_idxs = torch.cat(all_regularized_idxs).long()
-        all_regularization_weights = torch.cat(all_regularization_weights)
+        # Map global flat indices -> positions within the unfrozen vector.
+        # regularized_idxs are guaranteed to be a subset of unfrozen_idxs,
+        # so no missing-key guard is needed.
+        combined_regularized_idxs = torch.cat(all_regularized_idxs).long()
+        combined_regularization_weights = torch.cat(all_regularization_weights)
 
-        # Map global indices to unfrozen indices
         idx_mapping = {idx.item(): i for i, idx in enumerate(self._unfrozen_idxs)}
 
         self._regularized_idxs = torch.tensor(
-            [
-                idx_mapping[idx.item()]
-                for idx in all_regularized_idxs
-                if idx.item() in idx_mapping
-            ]
+            [idx_mapping[idx.item()] for idx in combined_regularized_idxs]
         ).long()
 
-        regularization_weights = [
-            all_regularization_weights[i]
-            for i, idx in enumerate(all_regularized_idxs)
-            if idx.item() in idx_mapping
-        ]
         self._regularization_weights = (
-            torch.stack(regularization_weights)
-            if regularization_weights
+            combined_regularization_weights
+            if len(combined_regularization_weights) > 0
             else torch.tensor([])
         )
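
Taken together, the series makes vsite geometry trainable through the same
`Trainable` interface as ordinary parameters, with per-column scaling,
clamping, and regularization weights. A minimal usage sketch of the API the
patches converge on, assuming `tip4p_fb.offxml` is available to the installed
toolkit (as in the fixtures above); `compute_fitting_loss` is a hypothetical
placeholder for a real objective, and the L2 penalty shown is just one way to
consume the exposed regularization indices and weights:

import openff.interchange
import openff.toolkit
import smee.converters

from descent.train import ParameterConfig, Trainable

interchange = openff.interchange.Interchange.from_smirnoff(
    openff.toolkit.ForceField("tip4p_fb.offxml"),
    openff.toolkit.Molecule.from_smiles("O").to_topology(),
)
force_field, _ = smee.converters.convert_interchange(interchange)

trainable = Trainable(
    force_field,
    parameters={
        "vdW": ParameterConfig(
            cols=["epsilon", "sigma"],
            scales={"epsilon": 10.0, "sigma": 1.0},
            limits={"epsilon": (0.0, None), "sigma": (0.0, None)},
        )
    },
    attributes={},
    vsites=ParameterConfig(
        cols=["distance"],
        scales={"distance": 10.0},
        limits={"distance": (-1.0, -0.01)},
        regularize={"distance": 0.25},
    ),
)

# the flat, scaled vector of unfrozen values is the optimisation target
x = trainable.to_values().detach().requires_grad_(True)

# rebuild the force field from the flat vector and evaluate the objective
ff = trainable.to_force_field(x)
loss = compute_fitting_loss(ff)  # hypothetical placeholder objective

# add an L2 penalty on the regularized entries of the trainable vector,
# weighted by the per-column weights collected from the configs
loss = loss + (
    trainable.regularization_weights * x[trainable.regularized_idxs] ** 2
).sum()
loss.backward()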