In-place operations breaking gradient flow for attributes #148
Description
Hi,
I came across what appears to be a bug whereby in-place operations are breaking the gradient flow for attributes of the LJ potential. This is particularly relevant for those optimising vdW attributes (e.g. 1–4 fudge factors) in periodic systems. The snippet below should reproduce the issue:
import torch
import smee
import smee.converters
import openff.toolkit
import openff.interchange
torch.autograd.set_detect_anomaly(True)
# Create a water molecule and ff
mol = openff.toolkit.Molecule.from_smiles("O")
ff = openff.toolkit.ForceField("openff-2.0.0.offxml")
# Generate positions and box
mol.generate_conformers(n_conformers=1)
pos1 = torch.tensor(mol.conformers[0].magnitude)
pos2 = torch.tensor(mol.conformers[0].magnitude + 5.0)
coords = torch.vstack([pos1, pos2])
box = torch.eye(3) * 50.0
# Convert to smee system
interchange = openff.interchange.Interchange.from_smirnoff(ff, mol.to_topology())
tensor_ff, tensor_top = smee.converters.convert_interchange([interchange])
system = smee.TensorSystem(tensor_top, [2], is_periodic=True)
# Gradients on
vdw_potential = tensor_ff.potentials_by_type["vdW"]
vdw_potential.parameters.requires_grad = True
vdw_potential.attributes.requires_grad = True
# Get energy
energy = smee.compute_energy(system, tensor_ff, coords, box)
# Attempt backward pass
energy.backward()

In particular, the in-place operation causing the issue is this line:
https://github.com/openforcefield/smee/blob/main/smee/potentials/nonbonded.py#L358
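For context, here is a minimal standalone sketch (independent of smee) of why this kind of in-place write trips autograd: some ops save their output tensor for the backward pass, so mutating that output in place invalidates the saved tensor, just as assigning into r_pow does.

```python
import torch

x = torch.ones(3, requires_grad=True)
y = torch.exp(x)  # exp saves its output tensor for the backward pass
y[0] = 0.0        # in-place write bumps y's version counter

try:
    y.sum().backward()
except RuntimeError as e:
    # autograd detects that a tensor needed for the backward pass
    # was modified by an in-place operation
    print(type(e).__name__)
```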
The fix is pretty simple and the following patch should resolve the issue (I'd be happy to open a PR if helpful):
diff --git a/smee/potentials/nonbonded.py b/smee/potentials/nonbonded.py
index ced5646..d71aee5 100644
--- a/smee/potentials/nonbonded.py
+++ b/smee/potentials/nonbonded.py
@@ -352,10 +352,9 @@ def _integrate_lj_switch(
)
coeff_11 = smee.utils.tensor_like([84, -3780, 7560, 2520, -3780, 756], rs) * coeff_0
- r_pow = torch.pow(
- r, smee.utils.tensor_like([-9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2], rs)
- )
- r_pow[-3] = torch.log(r)
+ powers = smee.utils.tensor_like([-9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2], rs)
+ r_pow = torch.pow(r, powers)
+ r_pow = torch.where(powers == 0, torch.log(r), r_pow)
integral = (
-(b**3)
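As a quick sanity check of the torch.where pattern (again a standalone sketch, not using smee's actual tensors): the masked in-place assignment is replaced by an out-of-place select, so the backward pass runs cleanly and gradients stay finite.

```python
import torch

r = torch.tensor([1.5, 2.0], requires_grad=True)
powers = torch.tensor([-2.0, 0.0])

r_pow = torch.pow(r, powers)
# out-of-place: substitute log(r) wherever the exponent is zero
r_pow = torch.where(powers == 0, torch.log(r), r_pow)

r_pow.sum().backward()  # no in-place modification of a saved tensor
print(torch.isfinite(r.grad).all().item())
```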
I think this is especially relevant for those using descent because, even if an empty attribute configuration is passed to Trainable, the attributes will still require gradients, which triggers this issue. This is actually how I came across it:
import torch
import smee
import smee.converters
import openff.toolkit
import openff.interchange
from descent.train import Trainable, ParameterConfig, AttributeConfig
mol = openff.toolkit.Molecule.from_smiles("O")
ff = openff.toolkit.ForceField("openff-2.0.0.offxml")
interchange = openff.interchange.Interchange.from_smirnoff(ff, mol.to_topology())
tensor_ff, topologies = smee.converters.convert_interchange(interchange)
# Create a Trainable with empty attribute optimisation
parameters = {"vdW": ParameterConfig(cols=["epsilon", "sigma"])}
vdw_attribute_config = AttributeConfig(
cols=[],
scales={},
limits={},
)
attributes = {"vdW": vdw_attribute_config}
# Create trainable and reconstruct FF
trainable = Trainable(tensor_ff, parameters, attributes)
values = trainable.to_values()
ff_train = trainable.to_force_field(values)
# Check that attributes in the new FF have requires_grad=True
vdw_attr = ff_train.potentials_by_type["vdW"].attributes
print(f"Attributes requires_grad: {vdw_attr.requires_grad}")

Please let me know if you need any more info from my side.
Many thanks in advance.