diff --git a/smee/potentials/nonbonded.py b/smee/potentials/nonbonded.py index d71aee5..9ba1476 100644 --- a/smee/potentials/nonbonded.py +++ b/smee/potentials/nonbonded.py @@ -137,21 +137,25 @@ def _compute_pairwise_periodic( ( pair_idxs, - deltas, + _, distances, _, ) = NNPOps.neighbors.getNeighborPairs(conformer, cutoff.item(), -1, box_vectors) are_interacting = ~torch.isnan(distances) - - distances = distances[are_interacting] - deltas = deltas[are_interacting, :] pair_idxs = pair_idxs[:, are_interacting] - # we sort the indices to get values correponding to upper triangles - # but we need to track which have been reversed so we can reverse the deltas - pair_idxs, indices = pair_idxs.sort(dim=0) - reversed = -(indices[0] == 1).to(deltas.dtype) - deltas = deltas * reversed[:, None] + + # ensure i < j + pair_idxs, _ = pair_idxs.sort(dim=0) + + # we recompute the distances because the getNeighborPairs of NNPOps does not + # support double backward gradients, and we need the distances to be differentiable + # for computing, e.g., derivatives of the forces w.r.t. the FF parameters. + deltas = conformer[pair_idxs[0]] - conformer[pair_idxs[1]] + shifts = torch.round(deltas @ torch.linalg.inv(box_vectors)) + deltas = deltas - shifts @ box_vectors + + distances = torch.linalg.norm(deltas, dim=-1) return PairwiseDistances(pair_idxs.T.contiguous(), deltas, distances, cutoff) diff --git a/smee/tests/potentials/test_nonbonded.py b/smee/tests/potentials/test_nonbonded.py index b73872e..5ed147c 100644 --- a/smee/tests/potentials/test_nonbonded.py +++ b/smee/tests/potentials/test_nonbonded.py @@ -569,3 +569,38 @@ def test_compute_pairwise_periodic_indices(): ) # note, indices end up sorted into the upper triangular matrix assert torch.all(pairwise_distances.idxs == torch.tensor([[0, 1], [0, 2]])) + + +def test_compute_pairwise_periodic_double_backward(): + epsilon = torch.tensor([0.3], requires_grad=True) + sigma = torch.tensor([3.0], requires_grad=True) + cutoff = torch.tensor(9.0) + + # Two particles + coords = torch.tensor([[0.0, 0.0, 0.0], [3.1, 0.0, 0.0]], requires_grad=True) + box_vectors = torch.eye(3) * 30.0 + pairwise = _compute_pairwise_periodic(coords, box_vectors, cutoff) + assert pairwise.distances.grad_fn is not None + + # Get LJ energy for the pair + sig_r = sigma / pairwise.distances + energy = 4.0 * epsilon * (sig_r**12 - sig_r**6) + assert energy.grad_fn is not None + + # First backward, compute forces (F = -dE/dcoords) + forces = -torch.autograd.grad(energy, coords, create_graph=True, retain_graph=True)[ + 0 + ] + force_loss = (forces**2).sum() + + # Second backward, get parameter gradients + force_loss.backward() + epsilon_grad = epsilon.grad + sigma_grad = sigma.grad + + assert torch.isclose( + epsilon_grad, torch.tensor([20.0522], dtype=epsilon_grad.dtype), atol=1.0e-5 + ) + assert torch.isclose( + sigma_grad, torch.tensor([42.7793], dtype=sigma_grad.dtype), atol=1.0e-5 + )