Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions smee/potentials/nonbonded.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,10 +352,9 @@ def _integrate_lj_switch(
)
coeff_11 = smee.utils.tensor_like([84, -3780, 7560, 2520, -3780, 756], rs) * coeff_0

r_pow = torch.pow(
r, smee.utils.tensor_like([-9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2], rs)
)
r_pow[-3] = torch.log(r)
powers = smee.utils.tensor_like([-9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2], rs)
r_pow = torch.pow(r, powers)
r_pow = torch.where(powers == 0, torch.log(r), r_pow)

integral = (
-(b**3)
Expand Down
3 changes: 1 addition & 2 deletions smee/tests/convertors/openff/test_nonbonded.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import importlib.util

import pytest

import openff.interchange
import openff.interchange.models
import openff.toolkit
import openff.units
import pytest
import torch

import smee
Expand Down
2 changes: 1 addition & 1 deletion smee/tests/convertors/openff/test_openff.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import openff.interchange.models
import openff.toolkit
from openff.toolkit.utils.toolkit_registry import toolkit_registry_manager
import openff.units
import pytest
import torch
from openff.toolkit.utils.toolkit_registry import toolkit_registry_manager

import smee
import smee.tests.utils
Expand Down
2 changes: 1 addition & 1 deletion smee/tests/potentials/test_nonbonded.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
_COULOMB_PRE_FACTOR,
_compute_dexp_lrc,
_compute_lj_lrc,
_compute_pme_exclusions,
_compute_pairwise_periodic,
_compute_pme_exclusions,
compute_coulomb_energy,
compute_dexp_energy,
compute_lj_energy,
Expand Down
29 changes: 29 additions & 0 deletions smee/tests/potentials/test_potentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,3 +232,32 @@ def test_compute_energy_v_sites():
energy_smee = compute_energy(tensor_top, tensor_ff, conformer)

assert torch.isclose(energy_smee, energy_openmm.to(energy_smee.dtype))


@pytest.mark.parametrize("periodic", [True, False])
def test_energy_backward_pass(periodic):
tensor_sys, tensor_ff = smee.tests.utils.system_from_smiles(["CCC", "O"], [2, 3])
tensor_sys.is_periodic = periodic

coords, _ = smee.mm.generate_system_coords(tensor_sys, None)
coords = torch.tensor(coords.value_in_unit(openmm.unit.angstrom))
box_vectors = torch.eye(3, dtype=coords.dtype) * 30.0 if periodic else None
coords.requires_grad = True

for potential in tensor_ff.potentials:
potential.parameters.requires_grad = True
if potential.attributes is not None:
potential.attributes.requires_grad = True

energy = smee.compute_energy(tensor_sys, tensor_ff, coords, box_vectors)
energy.backward()

assert coords.grad is not None
for potential in tensor_ff.potentials:
assert potential.parameters.grad is not None
assert potential.parameters.grad.shape == potential.parameters.shape
assert not torch.isnan(potential.parameters.grad).any()
if potential.attributes is not None:
assert potential.attributes.grad is not None
assert potential.attributes.grad.shape == potential.attributes.shape
assert not torch.isnan(potential.attributes.grad).any()
Loading