Conversation
Welcome to Codecov 🎉Once you merge this PR into your default branch, you're all set! Codecov will compare coverage reports and display results in all future pull requests. Thanks for integrating Codecov - We've got you covered ☂️ |
|
Thanks for opening this PR @jthorton, are you still keen on having it in? If so do you think it would be straightforward to refactor the vsites so they behave similarly to the other potentials in |
|
Hi @lilyminium yeah, this would be great to have and probably good for @fjclark and @JMorado as well.
Its been a while since I looked at it, but from what I remeber the issue was that some info is missing from the tensor vsite potential like the distance and angles, so this was needed as a workaround, but it might be possible to fix this. |
|
Thanks both, yes this would be great to have for the long run! I'll try to take a look this week. |
| openff.toolkit.ForceField("tip4p_fb.offxml"), | ||
| openff.toolkit.Molecule.from_smiles("O").to_topology(), | ||
| ) | ||
| ff, _ = smee.converters.convert_interchange(interachange) |
| n_rows = vsite_parameters.shape[0] | ||
| vsite_parameters_flat = vsite_parameters.flatten() | ||
| # define the cols as they are not on the tensor model | ||
| vsite_cols = ["distance", "inPlaneAngle", "outOfPlaneAngle"] |
There was a problem hiding this comment.
Would it be any better to get these from e.g. list(tff.v_sites.default_units().keys())? Avoids hard-coding here but is maybe less clear.
| key_to_row = {key: row_idx for row_idx, key in enumerate(all_keys)} | ||
| assert len(key_to_row) == len(all_keys), "duplicate keys found" | ||
|
|
||
| unfrozen_rows = { | ||
| key_to_row[key] for key in unfrozen_keys if key not in excluded_keys | ||
| } | ||
|
|
||
| unfrozen_idxs = [ | ||
| col_idx + row_idx * vsite_parameters.shape[1] | ||
| for row_idx in range(n_rows) | ||
| if row_idx in unfrozen_rows | ||
| # the vsite model has no parameter cols so define here | ||
| for col_idx, col in enumerate(vsite_cols) | ||
| if col in config.cols | ||
| ] | ||
| vsite_scales = [config.scales.get(col, 1.0) for col in vsite_cols] * n_rows | ||
| clamp_lower = [ | ||
| config.limits.get(col, (None, None))[0] for col in vsite_cols | ||
| ] * n_rows | ||
| clamp_lower = [-torch.inf if x is None else x for x in clamp_lower] | ||
| clamp_upper = [ | ||
| config.limits.get(col, (None, None))[1] for col in vsite_cols | ||
| ] * n_rows | ||
| clamp_upper = [torch.inf if x is None else x for x in clamp_upper] |
There was a problem hiding this comment.
Maybe we could refactor a bit to avoid duplication with the standard _prepare fn?
|
Looks good to me, though I'll need to add regularisation support consistent with the normal parameters. I'll create a separate PR based off this to do that. |
|
Sorry I wasn't clear -- #89 includes the suggested changes. I'll try to address the review comments later this week. |
Description
This PR extends the
Trainableclass to also work with vsites making it much easier to fit the geometric parameters of the sites such as the distance from the parent atom.Status