diff --git a/smee/potentials/nonbonded.py b/smee/potentials/nonbonded.py index ced5646..d71aee5 100644 --- a/smee/potentials/nonbonded.py +++ b/smee/potentials/nonbonded.py @@ -352,10 +352,9 @@ def _integrate_lj_switch( ) coeff_11 = smee.utils.tensor_like([84, -3780, 7560, 2520, -3780, 756], rs) * coeff_0 - r_pow = torch.pow( - r, smee.utils.tensor_like([-9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2], rs) - ) - r_pow[-3] = torch.log(r) + powers = smee.utils.tensor_like([-9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2], rs) + r_pow = torch.pow(r, powers) + r_pow = torch.where(powers == 0, torch.log(r), r_pow) integral = ( -(b**3) diff --git a/smee/tests/convertors/openff/test_nonbonded.py b/smee/tests/convertors/openff/test_nonbonded.py index 2a00bbe..5220546 100644 --- a/smee/tests/convertors/openff/test_nonbonded.py +++ b/smee/tests/convertors/openff/test_nonbonded.py @@ -1,11 +1,10 @@ import importlib.util -import pytest - import openff.interchange import openff.interchange.models import openff.toolkit import openff.units +import pytest import torch import smee diff --git a/smee/tests/convertors/openff/test_openff.py b/smee/tests/convertors/openff/test_openff.py index 955e08f..53c29be 100644 --- a/smee/tests/convertors/openff/test_openff.py +++ b/smee/tests/convertors/openff/test_openff.py @@ -2,10 +2,10 @@ import openff.interchange.models import openff.toolkit -from openff.toolkit.utils.toolkit_registry import toolkit_registry_manager import openff.units import pytest import torch +from openff.toolkit.utils.toolkit_registry import toolkit_registry_manager import smee import smee.tests.utils diff --git a/smee/tests/potentials/test_nonbonded.py b/smee/tests/potentials/test_nonbonded.py index a09d6a5..b73872e 100644 --- a/smee/tests/potentials/test_nonbonded.py +++ b/smee/tests/potentials/test_nonbonded.py @@ -16,8 +16,8 @@ _COULOMB_PRE_FACTOR, _compute_dexp_lrc, _compute_lj_lrc, - _compute_pme_exclusions, _compute_pairwise_periodic, + _compute_pme_exclusions, compute_coulomb_energy, compute_dexp_energy, compute_lj_energy, diff --git a/smee/tests/potentials/test_potentials.py b/smee/tests/potentials/test_potentials.py index 31a6f0e..08c4af0 100644 --- a/smee/tests/potentials/test_potentials.py +++ b/smee/tests/potentials/test_potentials.py @@ -232,3 +232,32 @@ def test_compute_energy_v_sites(): energy_smee = compute_energy(tensor_top, tensor_ff, conformer) assert torch.isclose(energy_smee, energy_openmm.to(energy_smee.dtype)) + + +@pytest.mark.parametrize("periodic", [True, False]) +def test_energy_backward_pass(periodic): + tensor_sys, tensor_ff = smee.tests.utils.system_from_smiles(["CCC", "O"], [2, 3]) + tensor_sys.is_periodic = periodic + + coords, _ = smee.mm.generate_system_coords(tensor_sys, None) + coords = torch.tensor(coords.value_in_unit(openmm.unit.angstrom)) + box_vectors = torch.eye(3, dtype=coords.dtype) * 30.0 if periodic else None + coords.requires_grad = True + + for potential in tensor_ff.potentials: + potential.parameters.requires_grad = True + if potential.attributes is not None: + potential.attributes.requires_grad = True + + energy = smee.compute_energy(tensor_sys, tensor_ff, coords, box_vectors) + energy.backward() + + assert coords.grad is not None + for potential in tensor_ff.potentials: + assert potential.parameters.grad is not None + assert potential.parameters.grad.shape == potential.parameters.shape + assert not torch.isnan(potential.parameters.grad).any() + if potential.attributes is not None: + assert potential.attributes.grad is not None + assert potential.attributes.grad.shape == potential.attributes.shape + assert not torch.isnan(potential.attributes.grad).any()