Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions smee/potentials/nonbonded.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,21 +137,25 @@ def _compute_pairwise_periodic(

(
pair_idxs,
deltas,
_,
distances,
_,
) = NNPOps.neighbors.getNeighborPairs(conformer, cutoff.item(), -1, box_vectors)

are_interacting = ~torch.isnan(distances)

distances = distances[are_interacting]
deltas = deltas[are_interacting, :]
pair_idxs = pair_idxs[:, are_interacting]
# we sort the indices to get values correponding to upper triangles
# but we need to track which have been reversed so we can reverse the deltas
pair_idxs, indices = pair_idxs.sort(dim=0)
reversed = -(indices[0] == 1).to(deltas.dtype)
deltas = deltas * reversed[:, None]

# ensure i < j
pair_idxs, _ = pair_idxs.sort(dim=0)

# we recompute the distances because the getNeighborPairs of NNPOps does not
# support double backward gradients, and we need the distances to be differentiable
# for computing, e.g., derivatives of the forces w.r.t. the FF parameters.
deltas = conformer[pair_idxs[0]] - conformer[pair_idxs[1]]
shifts = torch.round(deltas @ torch.linalg.inv(box_vectors))
deltas = deltas - shifts @ box_vectors

distances = torch.linalg.norm(deltas, dim=-1)

return PairwiseDistances(pair_idxs.T.contiguous(), deltas, distances, cutoff)

Expand Down
35 changes: 35 additions & 0 deletions smee/tests/potentials/test_nonbonded.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,3 +569,38 @@ def test_compute_pairwise_periodic_indices():
)
# note, indices end up sorted into the upper triangular matrix
assert torch.all(pairwise_distances.idxs == torch.tensor([[0, 1], [0, 2]]))


def test_compute_pairwise_periodic_double_backward():
epsilon = torch.tensor([0.3], requires_grad=True)
sigma = torch.tensor([3.0], requires_grad=True)
cutoff = torch.tensor(9.0)

# Two particles
coords = torch.tensor([[0.0, 0.0, 0.0], [3.1, 0.0, 0.0]], requires_grad=True)
box_vectors = torch.eye(3) * 30.0
pairwise = _compute_pairwise_periodic(coords, box_vectors, cutoff)
assert pairwise.distances.grad_fn is not None

# Get LJ energy for the pair
sig_r = sigma / pairwise.distances
energy = 4.0 * epsilon * (sig_r**12 - sig_r**6)
assert energy.grad_fn is not None

# First backward, compute forces (F = -dE/dcoords)
forces = -torch.autograd.grad(energy, coords, create_graph=True, retain_graph=True)[
0
]
force_loss = (forces**2).sum()

# Second backward, get parameter gradients
force_loss.backward()
epsilon_grad = epsilon.grad
sigma_grad = sigma.grad

assert torch.isclose(
epsilon_grad, torch.tensor([20.0522], dtype=epsilon_grad.dtype), atol=1.0e-5
)
assert torch.isclose(
sigma_grad, torch.tensor([42.7793], dtype=sigma_grad.dtype), atol=1.0e-5
)
Loading