diff --git a/descent/tests/test_train.py b/descent/tests/test_train.py index 1b445ee..2a8dee0 100644 --- a/descent/tests/test_train.py +++ b/descent/tests/test_train.py @@ -7,8 +7,10 @@ import pytest import smee import smee.converters +import smee.utils import torch + from descent.train import AttributeConfig, ParameterConfig, Trainable, _PotentialKey @@ -46,6 +48,45 @@ def mock_ff() -> smee.TensorForceField: return ff +@pytest.fixture() +def water_sites_ff(): + interchange = openff.interchange.Interchange.from_smirnoff( + openff.toolkit.ForceField("tip4p_fb.offxml"), + openff.toolkit.Molecule.from_smiles("O").to_topology(), + ) + ff, _ = smee.converters.convert_interchange(interchange) + # make sure we have vsites in the force field + assert ff.v_sites is not None + # this is awkward to specify in the yaml config file can we make it easier? + expected_ids = ["[#1:2]-[#8X2H2+0:1]-[#1:3] EP once"] + vsite_ids = [key.id for key in ff.v_sites.keys] + assert vsite_ids == expected_ids + + return ff + + +@pytest.fixture() +def mock_vsite_configs(water_sites_ff): + return ParameterConfig( + cols=["distance"], + scales={"distance": 10.0}, + limits={"distance": (-1.0, -0.01)}, + include=[water_sites_ff.v_sites.keys[0]], + ) + + +@pytest.fixture() +def mock_water_parameter_config(water_sites_ff): + return { + "vdW": ParameterConfig( + cols=["epsilon", "sigma"], + scales={"epsilon": 10.0, "sigma": 1.0}, + limits={"epsilon": (0.0, None), "sigma": (0.0, None)}, + include=[water_sites_ff.potentials_by_type["vdW"].parameter_keys[0]], + ), + } + + @pytest.fixture() def mock_parameter_configs(mock_ff): return { @@ -286,148 +327,150 @@ def test_clamp(self, mock_ff, mock_parameter_configs, mock_attribute_configs): assert values.shape == expected_values.shape assert torch.allclose(values, expected_values) - def test_regularized_idxs_no_regularization( - self, mock_ff, mock_parameter_configs, mock_attribute_configs + def test_init_vsites( + self, water_sites_ff, mock_vsite_configs, mock_water_parameter_config ): trainable = Trainable( - mock_ff, - parameters=mock_parameter_configs, - attributes=mock_attribute_configs, + water_sites_ff, + parameters=mock_water_parameter_config, + attributes={}, + vsites=mock_vsite_configs, ) - assert len(trainable.regularized_idxs) == 0 - assert len(trainable.regularization_weights) == 0 + assert trainable._param_types == ["vdW"] + # check we have a vdW parameter for the oxygen, hydrogen and vsite + assert trainable._param_shapes == [(3, 2)] + assert trainable._attr_types == [] - def test_regularized_idxs_with_parameter_regularization(self, mock_ff): - parameter_configs = { - "vdW": ParameterConfig( - cols=["epsilon", "sigma"], - regularize={"epsilon": 0.01, "sigma": 0.001}, + assert trainable._values.shape == (9,) + assert torch.allclose( + trainable._values, + torch.cat( + [ + water_sites_ff.potentials_by_type["vdW"].parameters.flatten(), + water_sites_ff.v_sites.parameters.flatten(), + ] ), - } - attribute_configs = {} - - trainable = Trainable( - mock_ff, - parameters=parameter_configs, - attributes=attribute_configs, ) - # vdW has 2 parameters (2 rows), and we're regularizing both epsilon and sigma - # So we should have 4 regularized values total: 2 epsilons + 2 sigmas - expected_idxs = torch.tensor([0, 1, 2, 3], dtype=torch.long) - assert torch.equal(trainable.regularized_idxs, expected_idxs) + # check frozen parameters + # vdW params: eps, sig, eps, sig where only first smirks is unfrozen + # vsite params: dist, inplane, outplane where first smirks is unfrozen + expected_unfrozen_ids = torch.tensor([0, 1, 6]) + assert (trainable._unfrozen_idxs == expected_unfrozen_ids).all() - # Check the weights match what we configured - # Interleaved: row 0 (eps, sig), row 1 (eps, sig) - expected_weights = torch.tensor( - [0.01, 0.001, 0.01, 0.001], dtype=trainable.regularization_weights.dtype + assert torch.allclose( + trainable._clamp_lower, + torch.tensor([0.0, 0.0, -1.0], dtype=torch.float64), + ) + assert torch.allclose( + trainable._clamp_upper, + torch.tensor([torch.inf, torch.inf, -0.01], dtype=torch.float64), + ) + assert torch.allclose( + trainable._scales, + torch.tensor([10.0, 1.0, 10.0], dtype=torch.float64), ) - assert torch.allclose(trainable.regularization_weights, expected_weights) - - def test_regularized_idxs_with_attribute_regularization(self, mock_ff): - parameter_configs = {} - attribute_configs = { - "vdW": AttributeConfig( - cols=["scale_14", "scale_15"], - regularize={"scale_14": 0.05}, - ) - } + def test_to_values_vsites( + self, water_sites_ff, mock_vsite_configs, mock_water_parameter_config + ): trainable = Trainable( - mock_ff, - parameters=parameter_configs, - attributes=attribute_configs, + water_sites_ff, + parameters=mock_water_parameter_config, + attributes={}, + vsites=mock_vsite_configs, ) + vdw_params = water_sites_ff.potentials_by_type["vdW"].parameters.flatten() + vsite_params = water_sites_ff.v_sites.parameters.flatten() - # Only scale_14 should be regularized (1 attribute) - expected_idxs = torch.tensor([0], dtype=torch.long) - assert torch.equal(trainable.regularized_idxs, expected_idxs) - - expected_weights = torch.tensor( - [0.05], dtype=trainable.regularization_weights.dtype + expected_values = torch.tensor( + [ + vdw_params[0] * 10, # scale eps + vdw_params[1], # sigma no scale + vsite_params[0] * 10, # scale vsite distance + ] ) - assert torch.allclose(trainable.regularization_weights, expected_weights) + values = trainable.to_values() - def test_regularized_idxs_with_mixed_regularization(self, mock_ff): - parameter_configs = { - "vdW": ParameterConfig( - cols=["epsilon", "sigma"], - regularize={"epsilon": 0.02}, - include=[mock_ff.potentials_by_type["vdW"].parameter_keys[0]], - ), - } - attribute_configs = { - "vdW": AttributeConfig( - cols=["scale_14"], - regularize={"scale_14": 0.1}, - ) - } + assert values.shape == expected_values.shape + assert torch.allclose(values, expected_values) + + def test_to_force_field_vsites_no_op( + self, water_sites_ff, mock_vsite_configs, mock_water_parameter_config + ): + ff_initial = copy.deepcopy(water_sites_ff) trainable = Trainable( - mock_ff, - parameters=parameter_configs, - attributes=attribute_configs, + water_sites_ff, + parameters=mock_water_parameter_config, + attributes={}, + vsites=mock_vsite_configs, ) - # Only first vdW parameter row is included, with only epsilon regularized - # Plus scale_14 attribute - expected_idxs = torch.tensor([0, 2], dtype=torch.long) - assert torch.equal(trainable.regularized_idxs, expected_idxs) - - # First should be epsilon (0.02), second should be scale_14 (0.1) - expected_weights = torch.tensor( - [0.02, 0.1], dtype=trainable.regularization_weights.dtype + ff = trainable.to_force_field(trainable.to_values()) + assert ( + ff.potentials_by_type["vdW"].parameters.shape + == ff_initial.potentials_by_type["vdW"].parameters.shape + ) + assert torch.allclose( + ff.potentials_by_type["vdW"].parameters, + ff_initial.potentials_by_type["vdW"].parameters, + ) + # vsite parameters are not float64 in the initial ff + vsite_initial = smee.utils.tensor_like( + ff_initial.v_sites.parameters, ff.v_sites.parameters ) - assert torch.allclose(trainable.regularization_weights, expected_weights) - def test_regularized_idxs_excluded_parameters(self, mock_ff): - parameter_configs = { - "Bonds": ParameterConfig( - cols=["k", "length"], - regularize={"k": 0.01, "length": 0.02}, - exclude=[mock_ff.potentials_by_type["Bonds"].parameter_keys[0]], - ), - } - attribute_configs = {} + assert torch.allclose( + ff.v_sites.parameters, + vsite_initial, + ) + def test_to_force_field_clamp_vsites( + self, water_sites_ff, mock_vsite_configs, mock_water_parameter_config + ): trainable = Trainable( - mock_ff, - parameters=parameter_configs, - attributes=attribute_configs, + water_sites_ff, + parameters=mock_water_parameter_config, + attributes={}, + vsites=mock_vsite_configs, ) - # Only second bond parameter row should be included (first is excluded) - # Both k and length are regularized - expected_idxs = torch.tensor([0, 1], dtype=torch.long) - assert torch.equal(trainable.regularized_idxs, expected_idxs) + # The trainable values are, in order, the vdW parameters (eps, sigma) + # followed by the vsite distance. # When we set the last trainable + # value to 0.0, this corresponds to the vsite distance, which is the first + # parameter in ff.v_sites.parameters. - expected_weights = torch.tensor( - [0.01, 0.02], dtype=trainable.regularization_weights.dtype + values = trainable.to_values().detach() + # set the distance to outside the clamp region + values[-1] = 0.0 + ff = trainable.to_force_field(values) + assert torch.allclose( + ff.v_sites.parameters[0], + torch.tensor([-0.0100, 3.1416, 0.0000], dtype=torch.float64), ) - assert torch.allclose(trainable.regularization_weights, expected_weights) - def test_regularization_indices_match_unfrozen_values(self, mock_ff): - parameter_configs = { - "vdW": ParameterConfig( - cols=["epsilon"], - regularize={"epsilon": 0.01}, - ), - } - attribute_configs = {} + def test_init_vsites_regularization( + self, water_sites_ff, mock_water_parameter_config + ): + vsite_config = ParameterConfig( + cols=["distance"], + scales={"distance": 10.0}, + limits={"distance": (-1.0, -0.01)}, + regularize={"distance": 0.25}, + include=[water_sites_ff.v_sites.keys[0]], + ) trainable = Trainable( - mock_ff, - parameters=parameter_configs, - attributes=attribute_configs, + water_sites_ff, + parameters=mock_water_parameter_config, + attributes={}, + vsites=vsite_config, ) - values = trainable.to_values() - - # Regularization indices should be valid indices into the unfrozen values - assert trainable.regularized_idxs.max() < len(values) - assert trainable.regularized_idxs.min() >= 0 - - # We should be able to index the values tensor with regularization indices - regularized_values = values[trainable.regularized_idxs] - assert len(regularized_values) == len(trainable.regularized_idxs) + assert torch.equal(trainable.regularized_idxs, torch.tensor([2])) + assert torch.allclose( + trainable.regularization_weights, + torch.tensor([0.25], dtype=torch.float64), + ) diff --git a/descent/train.py b/descent/train.py index 1dcb40f..9a5b83e 100644 --- a/descent/train.py +++ b/descent/train.py @@ -26,6 +26,29 @@ def _unflatten_tensors( return tensors +def _tensor_like_or_empty(values: list[float], like: torch.Tensor) -> torch.Tensor: + """Create a tensor like `like`, returning an empty one when no values exist.""" + return ( + smee.utils.tensor_like(values, like) + if values else smee.utils.tensor_like([], like) + ) + + +def _validate_trainable_keys( + cols: list[str], + scales: dict[str, float], + limits: dict[str, tuple[float | None, float | None]], + regularize: dict[str, float], +) -> None: + """Ensure transform dictionaries only reference trainable columns.""" + if any(key not in cols for key in scales): + raise ValueError("cannot scale non-trainable parameters") + if any(key not in cols for key in limits): + raise ValueError("cannot clamp non-trainable parameters") + if any(key not in cols for key in regularize): + raise ValueError("cannot regularize non-trainable parameters") + + if pydantic.__version__.startswith("1."): _PotentialKey = openff.interchange.models.PotentialKey PotentialKeyList = list[_PotentialKey] @@ -62,7 +85,7 @@ def _convert_keys(value: typing.Any) -> typing.Any: if not isinstance(value, list): return value - value = [ + return [ ( _PotentialKey(**v.dict()) if isinstance(v, openff.interchange.models.PotentialKey) @@ -70,7 +93,6 @@ def _convert_keys(value: typing.Any) -> typing.Any: ) for v in value ] - return value PotentialKeyList = typing.Annotated[ list[_PotentialKey], pydantic.BeforeValidator(_convert_keys) @@ -112,12 +134,7 @@ def _validate_keys(cls, values): limits = values.get("limits") regularize = values.get("regularize") - if any(key not in cols for key in scales): - raise ValueError("cannot scale non-trainable parameters") - if any(key not in cols for key in limits): - raise ValueError("cannot clamp non-trainable parameters") - if any(key not in cols for key in regularize): - raise ValueError("cannot regularize non-trainable parameters") + _validate_trainable_keys(cols, scales, limits, regularize) return values @@ -127,14 +144,9 @@ def _validate_keys(cls, values): def _validate_keys(self): """Ensure that the keys in `scales` and `limits` match `cols`.""" - if any(key not in self.cols for key in self.scales): - raise ValueError("cannot scale non-trainable parameters") - - if any(key not in self.cols for key in self.limits): - raise ValueError("cannot clamp non-trainable parameters") - - if any(key not in self.cols for key in self.regularize): - raise ValueError("cannot regularize non-trainable parameters") + _validate_trainable_keys( + self.cols, self.scales, self.limits, self.regularize + ) return self @@ -187,6 +199,25 @@ def _validate_include_exclude(self): return self +class _PreparedBlock(typing.NamedTuple): + """Per-block output of ``_prepare`` and ``_prepare_vsites``. + + All index tensors are zero-based relative to this block's own flat value + tensor. ``__init__`` is responsible for applying global offsets when + combining blocks. + """ + + values: torch.Tensor + shapes: list[torch.Size] + unfrozen_idxs: torch.Tensor + scales: torch.Tensor + clamp_lower: torch.Tensor + clamp_upper: torch.Tensor + # Always a subset of unfrozen_idxs. + regularized_idxs: torch.Tensor + regularization_weights: torch.Tensor + + class Trainable: """A convenient wrapper around a tensor force field that gives greater control over how parameters should be trained. @@ -196,15 +227,96 @@ class Trainable: parameters so they are not updated during training. """ + @staticmethod + def _prepare_rows( + config: AttributeConfig, + n_rows: int, + all_keys: list[_PotentialKey] | None = None, + ) -> set[int]: + """Determine which rows should be trainable.""" + if not isinstance(config, ParameterConfig): + return set(range(n_rows)) + + assert all_keys is not None + + excluded_keys = config.exclude or [] + unfrozen_keys = config.include or all_keys + + key_to_row = {key: row_idx for row_idx, key in enumerate(all_keys)} + assert len(key_to_row) == len(all_keys), "duplicate keys found" + + return {key_to_row[key] for key in unfrozen_keys if key not in excluded_keys} + + @staticmethod + def _prepare_values( + values: torch.Tensor, + cols: list[str] | tuple[str, ...], + config: AttributeConfig, + unfrozen_rows: set[int], + ) -> tuple[ + list[int], + list[float], + list[float], + list[float], + list[int], + list[float], + ]: + """Prepare unfrozen indices and transform metadata for a value block. + + Returned indices are zero-based relative to this block's flat value + tensor. The caller is responsible for applying any global offset before + combining indices across blocks. + + Only unfrozen entries are accumulated, so all returned lists are + parallel and require no further post-hoc indexing. + """ + row_width = values.shape[-1] + col_to_idx = {col: idx for idx, col in enumerate(cols)} + + unfrozen_idxs = [] + scales = [] + clamp_lower = [] + clamp_upper = [] + regularized_idxs = [] + regularization_weights = [] + + for row_idx in unfrozen_rows: + for col in config.cols: + flat_idx = col_to_idx[col] + row_idx * row_width + + unfrozen_idxs.append(flat_idx) + scales.append(config.scales.get(col, 1.0)) + + lower, upper = config.limits.get(col, (None, None)) + clamp_lower.append(-torch.inf if lower is None else lower) + clamp_upper.append(torch.inf if upper is None else upper) + + if col in config.regularize: + regularized_idxs.append(flat_idx) + regularization_weights.append(config.regularize[col]) + + return ( + unfrozen_idxs, + scales, + clamp_lower, + clamp_upper, + regularized_idxs, + regularization_weights, + ) + @classmethod def _prepare( cls, force_field: smee.TensorForceField, config: dict[str, AttributeConfig], attr: typing.Literal["parameters", "attributes"], - ): + ) -> _PreparedBlock: """Prepare the trainable parameters or attributes for the given force field and - configuration.""" + configuration. + + Returned indices are zero-based relative to the block's own flat value + tensor. + """ potential_types = sorted(config) potentials = [ force_field.potentials_by_type[potential_type] @@ -215,10 +327,9 @@ def _prepare( shapes = [] unfrozen_idxs = [] - unfrozen_col_offset = 0 + idx_offset = 0 scales = [] - clamp_lower = [] clamp_upper = [] @@ -241,88 +352,97 @@ def _prepare( n_rows = 1 if attr == "attributes" else len(potential_values) - unfrozen_rows = set(range(n_rows)) - + all_keys = None if isinstance(potential_config, ParameterConfig): all_keys = [ _PotentialKey(**v.dict()) for v in getattr(potential, f"{attr[:-1]}_keys") ] - excluded_keys = potential_config.exclude or [] - unfrozen_keys = potential_config.include or all_keys - - key_to_row = {key: row_idx for row_idx, key in enumerate(all_keys)} - assert len(key_to_row) == len(all_keys), "duplicate keys found" - - unfrozen_rows = { - key_to_row[key] for key in unfrozen_keys if key not in excluded_keys - } - - # Track unfrozen and regularized indices - for row_idx in range(n_rows): - if row_idx not in unfrozen_rows: - continue - for col_idx, col in enumerate(potential_cols): - if col not in potential_config.cols: - continue - - flat_idx = ( - unfrozen_col_offset - + col_idx - + row_idx * potential_values.shape[-1] - ) - unfrozen_idxs.append(flat_idx) - - if col in potential_config.regularize: - regularized_idxs.append(flat_idx) - regularization_weights.append(potential_config.regularize[col]) - - unfrozen_col_offset += len(potential_values_flat) - - potential_scales = [ - potential_config.scales.get(col, 1.0) for col in potential_cols - ] * n_rows + unfrozen_rows = cls._prepare_rows(potential_config, n_rows, all_keys) + ( + potential_unfrozen_idxs, + potential_scales, + potential_clamp_lower, + potential_clamp_upper, + potential_regularized_idxs, + potential_regularization_weights, + ) = cls._prepare_values( + potential_values, + potential_cols, + potential_config, + unfrozen_rows, + ) + unfrozen_idxs.extend(i + idx_offset for i in potential_unfrozen_idxs) scales.extend(potential_scales) - - potential_clamp_lower = [ - potential_config.limits.get(col, (None, None))[0] - for col in potential_cols - ] * n_rows - potential_clamp_lower = [ - -torch.inf if x is None else x for x in potential_clamp_lower - ] clamp_lower.extend(potential_clamp_lower) - - potential_clamp_upper = [ - potential_config.limits.get(col, (None, None))[1] - for col in potential_cols - ] * n_rows - potential_clamp_upper = [ - torch.inf if x is None else x for x in potential_clamp_upper - ] clamp_upper.extend(potential_clamp_upper) + regularized_idxs.extend(i + idx_offset for i in potential_regularized_idxs) + regularization_weights.extend(potential_regularization_weights) - values = ( + idx_offset += len(potential_values_flat) + + flat_values = ( smee.utils.tensor_like([], force_field.potentials[0].parameters) if len(values) == 0 else torch.cat(values) ) - return ( - potential_types, - values, - shapes, - torch.tensor(unfrozen_idxs), - smee.utils.tensor_like(scales, values), - smee.utils.tensor_like(clamp_lower, values), - smee.utils.tensor_like(clamp_upper, values), - torch.tensor(regularized_idxs), - ( - smee.utils.tensor_like(regularization_weights, values) - if regularization_weights - else smee.utils.tensor_like([], values) + return _PreparedBlock( + values=flat_values, + shapes=shapes, + unfrozen_idxs=torch.tensor(unfrozen_idxs), + scales=smee.utils.tensor_like(scales, flat_values), + clamp_lower=smee.utils.tensor_like(clamp_lower, flat_values), + clamp_upper=smee.utils.tensor_like(clamp_upper, flat_values), + regularized_idxs=torch.tensor(regularized_idxs), + regularization_weights=_tensor_like_or_empty( + regularization_weights, flat_values + ), + ) + + def _prepare_vsites( + self, force_field: smee.TensorForceField, config: ParameterConfig + ) -> _PreparedBlock: + """Prepare the vsite parameters for optimisation. + + Returned indices are zero-based relative to the block's own flat value + tensor. + """ + vsite_parameters = force_field.v_sites.parameters.detach().clone() + n_rows = vsite_parameters.shape[0] + vsite_parameters_flat = vsite_parameters.flatten() + # Getting the column names from this dict is a bit awkward but + # avoids hard-coding them. + vsite_cols = list(force_field.v_sites.default_units().keys()) + + all_keys = [_PotentialKey(**key.dict()) for key in force_field.v_sites.keys] + unfrozen_rows = self._prepare_rows(config, n_rows, all_keys) + ( + unfrozen_idxs, + vsite_scales, + clamp_lower, + clamp_upper, + regularized_idxs, + regularization_weights, + ) = self._prepare_values( + vsite_parameters, + vsite_cols, + config, + unfrozen_rows, + ) + + return _PreparedBlock( + values=vsite_parameters_flat, + shapes=[vsite_parameters.shape], + unfrozen_idxs=torch.tensor(unfrozen_idxs), + scales=smee.utils.tensor_like(vsite_scales, vsite_parameters), + clamp_lower=smee.utils.tensor_like(clamp_lower, vsite_parameters), + clamp_upper=smee.utils.tensor_like(clamp_upper, vsite_parameters), + regularized_idxs=torch.tensor(regularized_idxs), + regularization_weights=_tensor_like_or_empty( + regularization_weights, vsite_parameters ), ) @@ -331,6 +451,7 @@ def __init__( force_field: smee.TensorForceField, parameters: dict[str, ParameterConfig], attributes: dict[str, AttributeConfig], + vsites: ParameterConfig | None = None, ): """ @@ -338,72 +459,67 @@ def __init__( force_field: The force field to wrap. parameters: Configure which parameters to train. attributes: Configure which attributes to train. + vsites: Configure which vsite parameters to train. """ self._force_field = force_field + self._fit_vsites = False + + param_block = self._prepare(force_field, parameters, "parameters") + attr_block = self._prepare(force_field, attributes, "attributes") + + self._param_types = sorted(parameters) + self._param_shapes = param_block.shapes + self._attr_types = sorted(attributes) + self._attr_shapes = attr_block.shapes + + blocks: list[_PreparedBlock] = [param_block, attr_block] + + if vsites is not None: + blocks.append(self._prepare_vsites(force_field, vsites)) + self._fit_vsites = True + + # Each block's indices are zero-based within that block. Apply the + # running value-tensor offset at the point of combination so that the + # offset arithmetic is in one place and adding further blocks requires + # no changes beyond appending to `blocks`. + all_values = [] + all_unfrozen_idxs = [] + all_scales = [] + all_clamp_lower = [] + all_clamp_upper = [] + all_regularized_idxs = [] + all_regularization_weights = [] + + offset = 0 + for block in blocks: + all_values.append(block.values) + all_unfrozen_idxs.append(block.unfrozen_idxs + offset) + all_scales.append(block.scales) + all_clamp_lower.append(block.clamp_lower) + all_clamp_upper.append(block.clamp_upper) + all_regularized_idxs.append(block.regularized_idxs + offset) + all_regularization_weights.append(block.regularization_weights) + offset += len(block.values) + + self._values = torch.cat(all_values) + self._unfrozen_idxs = torch.cat(all_unfrozen_idxs).long() + self._scales = torch.cat(all_scales) + self._clamp_lower = torch.cat(all_clamp_lower) + self._clamp_upper = torch.cat(all_clamp_upper) + + # Map global flat indices -> positions within the unfrozen vector. + # regularized_idxs are guaranteed to be a subset of unfrozen_idxs, + # so no missing-key guard is needed. + combined_regularized_idxs = torch.cat(all_regularized_idxs).long() + combined_regularization_weights = torch.cat(all_regularization_weights) - ( - self._param_types, - param_values, - self._param_shapes, - param_unfrozen_idxs, - param_scales, - param_clamp_lower, - param_clamp_upper, - param_regularized_idxs, - param_regularization_weights, - ) = self._prepare(force_field, parameters, "parameters") - ( - self._attr_types, - attr_values, - self._attr_shapes, - attr_unfrozen_idxs, - attr_scales, - attr_clamp_lower, - attr_clamp_upper, - attr_regularized_idxs, - attr_regularization_weights, - ) = self._prepare(force_field, attributes, "attributes") - - self._values = torch.cat([param_values, attr_values]) - - self._unfrozen_idxs = torch.cat( - [param_unfrozen_idxs, attr_unfrozen_idxs + len(param_scales)] - ).long() - - self._scales = torch.cat([param_scales, attr_scales])[self._unfrozen_idxs] - - self._clamp_lower = torch.cat([param_clamp_lower, attr_clamp_lower])[ - self._unfrozen_idxs - ] - self._clamp_upper = torch.cat([param_clamp_upper, attr_clamp_upper])[ - self._unfrozen_idxs - ] - - # Store regularization information - all_regularized_idxs = torch.cat( - [param_regularized_idxs, attr_regularized_idxs + len(param_scales)] - ).long() - all_regularization_weights = torch.cat( - [param_regularization_weights, attr_regularization_weights] - ) - - # Map global indices to unfrozen indices idx_mapping = {idx.item(): i for i, idx in enumerate(self._unfrozen_idxs)} self._regularized_idxs = torch.tensor( - [ - idx_mapping[idx.item()] - for idx in all_regularized_idxs - if idx.item() in idx_mapping - ] + [idx_mapping[idx.item()] for idx in combined_regularized_idxs] ).long() - regularization_weights = [ - all_regularization_weights[i] - for i, idx in enumerate(all_regularized_idxs) - if idx.item() in idx_mapping - ] self._regularization_weights = ( - torch.stack(regularization_weights) - if regularization_weights + combined_regularization_weights + if len(combined_regularization_weights) > 0 else torch.tensor([]) ) @@ -427,18 +543,29 @@ def to_force_field(self, values_flat: torch.Tensor) -> smee.TensorForceField: values[self._unfrozen_idxs] = (values_flat / self._scales).clamp( min=self._clamp_lower, max=self._clamp_upper ) - values = _unflatten_tensors(values, self._param_shapes + self._attr_shapes) + shapes = self._param_shapes + self._attr_shapes + + if self._fit_vsites: + shapes.append(self._force_field.v_sites.parameters.shape) + + values = _unflatten_tensors(values, shapes) params = values[: len(self._param_shapes)] for potential_type, param in zip(self._param_types, params, strict=True): potentials[potential_type].parameters = param - attrs = values[len(self._param_shapes) :] + attrs = values[ + len(self._param_shapes) : len(self._param_shapes) + len(self._attr_shapes) + ] for potential_type, attr in zip(self._attr_types, attrs, strict=True): potentials[potential_type].attributes = attr + if self._fit_vsites: + vsite_params = values[len(self._param_shapes) + len(self._attr_shapes) :] + self._force_field.v_sites.parameters = vsite_params[0] + return self._force_field @torch.no_grad()