
In-place operations breaking gradient flow for attributes #148

@JMorado

Description

Hi,

I came across what appears to be a bug whereby in-place operations are breaking the gradient flow for attributes of the LJ potential. This is particularly relevant for those optimising vdW attributes (e.g. 1–4 fudge factors) in periodic systems. The snippet below should reproduce the issue:

import torch
import smee
import smee.converters
import openff.toolkit
import openff.interchange

torch.autograd.set_detect_anomaly(True)

# Create a water molecule and ff
mol = openff.toolkit.Molecule.from_smiles("O")
ff = openff.toolkit.ForceField("openff-2.0.0.offxml")

# Generate positions and box
mol.generate_conformers(n_conformers=1)
pos1 = torch.tensor(mol.conformers[0].magnitude)
pos2 = torch.tensor(mol.conformers[0].magnitude + 5.0)
coords = torch.vstack([pos1, pos2])
box = torch.eye(3) * 50.0

# Convert to smee system
interchange = openff.interchange.Interchange.from_smirnoff(ff, mol.to_topology())
tensor_ff, tensor_top = smee.converters.convert_interchange([interchange])
system = smee.TensorSystem(tensor_top, [2], is_periodic=True)

# Gradients on
vdw_potential = tensor_ff.potentials_by_type["vdW"]
vdw_potential.parameters.requires_grad = True
vdw_potential.attributes.requires_grad = True

# Get energy
energy = smee.compute_energy(system, tensor_ff, coords, box)

# Attempt backward pass
energy.backward()

In particular, the in-place operation causing the issue is this line:
https://github.com/openforcefield/smee/blob/main/smee/potentials/nonbonded.py#L358
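The failure mode can be reproduced in isolation, without smee at all. This is an illustrative sketch (the tensors and shapes here are made up, not smee's actual ones): `torch.pow` with a tensor exponent saves its output for the backward pass, so writing into that output in place bumps its version counter and `backward()` refuses to use it.

```python
import torch

# Standalone sketch of the bug pattern (illustrative values, not smee's):
# torch.pow with a tensor exponent saves its result for backward, so an
# in-place write into that result invalidates the saved tensor.
x = torch.tensor([2.0, 3.0], requires_grad=True)
powers = torch.tensor([-1.0, 2.0])

r_pow = torch.pow(x, powers)  # autograd saves this output tensor
r_pow[0] = torch.log(x[0])    # in-place write bumps its version counter

error = None
try:
    r_pow.sum().backward()
except RuntimeError as exc:
    # "... has been modified by an inplace operation"
    error = exc
```

This is the same shape of error `torch.autograd.set_detect_anomaly(True)` surfaces in the snippet above.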

The fix is pretty simple and the following patch should resolve the issue (I'd be happy to open a PR if helpful):

diff --git a/smee/potentials/nonbonded.py b/smee/potentials/nonbonded.py
index ced5646..d71aee5 100644
--- a/smee/potentials/nonbonded.py
+++ b/smee/potentials/nonbonded.py
@@ -352,10 +352,9 @@ def _integrate_lj_switch(
     )
     coeff_11 = smee.utils.tensor_like([84, -3780, 7560, 2520, -3780, 756], rs) * coeff_0

-    r_pow = torch.pow(
-        r, smee.utils.tensor_like([-9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2], rs)
-    )
-    r_pow[-3] = torch.log(r)
+    powers = smee.utils.tensor_like([-9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2], rs)
+    r_pow = torch.pow(r, powers)
+    r_pow = torch.where(powers == 0, torch.log(r), r_pow)

     integral = (
         -(b**3)

I think this is especially relevant for those using descent because, even if an empty attribute configuration is passed to Trainable, the attributes will still require gradients, which triggers this issue. This is actually how I came across it:

import torch
import smee
import smee.converters
import openff.toolkit
import openff.interchange
from descent.train import Trainable, ParameterConfig, AttributeConfig


mol = openff.toolkit.Molecule.from_smiles("O")
ff = openff.toolkit.ForceField("openff-2.0.0.offxml")
interchange = openff.interchange.Interchange.from_smirnoff(ff, mol.to_topology())
tensor_ff, topologies = smee.converters.convert_interchange([interchange])

# Create a Trainable with empty attribute optimisation
parameters = {"vdW": ParameterConfig(cols=["epsilon", "sigma"])}
vdw_attribute_config = AttributeConfig(
    cols=[],
    scales={},
    limits={},
)
attributes = {"vdW": vdw_attribute_config}

# Create trainable and reconstruct the FF
trainable = Trainable(tensor_ff, parameters, attributes)
values = trainable.to_values()
ff_train = trainable.to_force_field(values)

# Check that attributes in the new FF have requires_grad=True
vdw_attr = ff_train.potentials_by_type["vdW"].attributes
print(f"Attributes requires_grad: {vdw_attr.requires_grad}")

Please let me know if you need any more info from my side.
Many thanks in advance.
