From b13d544634dd816fa3cab837b3f8a81c96d3e16e Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Fri, 6 Mar 2026 10:31:49 +0000 Subject: [PATCH 1/3] Update _compute_pairwise_periodic to allow double backward of the energies. --- smee/potentials/nonbonded.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/smee/potentials/nonbonded.py b/smee/potentials/nonbonded.py index d71aee5..8604879 100644 --- a/smee/potentials/nonbonded.py +++ b/smee/potentials/nonbonded.py @@ -137,21 +137,22 @@ def _compute_pairwise_periodic( ( pair_idxs, - deltas, + _, distances, _, ) = NNPOps.neighbors.getNeighborPairs(conformer, cutoff.item(), -1, box_vectors) are_interacting = ~torch.isnan(distances) - - distances = distances[are_interacting] - deltas = deltas[are_interacting, :] pair_idxs = pair_idxs[:, are_interacting] - # we sort the indices to get values correponding to upper triangles - # but we need to track which have been reversed so we can reverse the deltas - pair_idxs, indices = pair_idxs.sort(dim=0) - reversed = -(indices[0] == 1).to(deltas.dtype) - deltas = deltas * reversed[:, None] + + # ensure i < j + pair_idxs, _ = pair_idxs.sort(dim=0) + + deltas = conformer[pair_idxs[0]] - conformer[pair_idxs[1]] + shifts = torch.round(deltas @ torch.linalg.inv(box_vectors)) + deltas = deltas - shifts @ box_vectors + + distances = torch.linalg.norm(deltas, dim=-1) return PairwiseDistances(pair_idxs.T.contiguous(), deltas, distances, cutoff) From 05ad9c46d66910e674b8a4f762b58b3402eb2b7e Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Fri, 6 Mar 2026 10:41:57 +0000 Subject: [PATCH 2/3] Add comment. --- smee/potentials/nonbonded.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/smee/potentials/nonbonded.py b/smee/potentials/nonbonded.py index 8604879..9ba1476 100644 --- a/smee/potentials/nonbonded.py +++ b/smee/potentials/nonbonded.py @@ -148,6 +148,9 @@ def _compute_pairwise_periodic( # ensure i < j pair_idxs, _ = pair_idxs.sort(dim=0) + # we recompute the distances because the getNeighborPairs of NNPOps does not + # support double backward gradients, and we need the distances to be differentiable + # for computing, e.g., derivatives of the forces w.r.t. the FF parameters. deltas = conformer[pair_idxs[0]] - conformer[pair_idxs[1]] shifts = torch.round(deltas @ torch.linalg.inv(box_vectors)) deltas = deltas - shifts @ box_vectors From cf0fd06692b031cb3cd328a2071910e623ebab7f Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Thu, 12 Mar 2026 16:03:30 +0000 Subject: [PATCH 3/3] Add test for double backward in compute_pairwise_periodic --- smee/tests/potentials/test_nonbonded.py | 35 +++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/smee/tests/potentials/test_nonbonded.py b/smee/tests/potentials/test_nonbonded.py index b73872e..5ed147c 100644 --- a/smee/tests/potentials/test_nonbonded.py +++ b/smee/tests/potentials/test_nonbonded.py @@ -569,3 +569,38 @@ def test_compute_pairwise_periodic_indices(): ) # note, indices end up sorted into the upper triangular matrix assert torch.all(pairwise_distances.idxs == torch.tensor([[0, 1], [0, 2]])) + + +def test_compute_pairwise_periodic_double_backward(): + epsilon = torch.tensor([0.3], requires_grad=True) + sigma = torch.tensor([3.0], requires_grad=True) + cutoff = torch.tensor(9.0) + + # Two particles + coords = torch.tensor([[0.0, 0.0, 0.0], [3.1, 0.0, 0.0]], requires_grad=True) + box_vectors = torch.eye(3) * 30.0 + pairwise = _compute_pairwise_periodic(coords, box_vectors, cutoff) + assert pairwise.distances.grad_fn is not None + + # Get LJ energy for the pair + sig_r = sigma / pairwise.distances + energy = 4.0 * epsilon * (sig_r**12 - sig_r**6) + assert energy.grad_fn is not None + + # First backward, compute forces (F = -dE/dcoords) + forces = -torch.autograd.grad(energy, coords, create_graph=True, retain_graph=True)[ + 0 + ] + force_loss = (forces**2).sum() + + # Second backward, get parameter gradients + force_loss.backward() + epsilon_grad = epsilon.grad + sigma_grad = sigma.grad + + assert torch.isclose( + epsilon_grad, torch.tensor([20.0522], dtype=epsilon_grad.dtype), atol=1.0e-5 + ) + assert torch.isclose( + sigma_grad, torch.tensor([42.7793], dtype=sigma_grad.dtype), atol=1.0e-5 + )