Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion descent/targets/energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@

DATA_SCHEMA = pyarrow.schema(
[
("id", pyarrow.string()),
("smiles", pyarrow.string()),
("coords", pyarrow.list_(pyarrow.float64())),
("box_vectors", pyarrow.list_(pyarrow.float64())),
("energy", pyarrow.list_(pyarrow.float64())),
("forces", pyarrow.list_(pyarrow.float64())),
]
Expand All @@ -22,6 +24,9 @@
class Entry(typing.TypedDict):
"""Represents a set of reference energies and forces."""

id: str | None
"""An optional identifier for the entry (e.g. a run name). Defaults to ``None``."""

smiles: str
"""The indexed SMILES description of the molecule the energies and forces were
computed for."""
Expand All @@ -34,6 +39,10 @@ class Entry(typing.TypedDict):
forces: torch.Tensor
"""The reference forces [kcal/mol/Å] with ``shape=(n_confs, n_particles, 3)``."""

box_vectors: torch.Tensor | None
"""The box vectors [Å] for periodic systems with ``shape=(n_confs, 3, 3)``, or
``None`` for non-periodic systems."""


def create_dataset(entries: list[Entry]) -> datasets.Dataset:
"""Create a dataset from a list of existing entries.
Expand All @@ -48,8 +57,12 @@ def create_dataset(entries: list[Entry]) -> datasets.Dataset:
table = pyarrow.Table.from_pylist(
[
{
"id": entry.get("id"),
"smiles": entry["smiles"],
"coords": torch.tensor(entry["coords"]).flatten().tolist(),
"box_vectors": None
if entry.get("box_vectors") is None
else torch.tensor(entry["box_vectors"]).flatten().tolist(),
"energy": torch.tensor(entry["energy"]).flatten().tolist(),
"forces": torch.tensor(entry["forces"]).flatten().tolist(),
}
Expand Down Expand Up @@ -118,9 +131,27 @@ def predict(
coords = (
(coords_flat.reshape(len(energy_ref), -1, 3)).detach().requires_grad_(True)
)
box_vectors = entry.get("box_vectors", None)

topology = topologies[smiles]

energy_pred = smee.compute_energy(topology, force_field, coords)
if box_vectors is not None:
# smee does not support batched periodic evaluations,
# so we loop over conformers.
box_vectors = smee.utils.tensor_like(box_vectors, coords_flat).reshape(
len(energy_ref), 3, 3
)
energy_pred = torch.cat(
[
smee.compute_energy(
topology, force_field, coords[i], box_vectors[i]
)
for i in range(len(energy_ref))
]
)
else:
energy_pred = smee.compute_energy(topology, force_field, coords, None)

forces_pred = -torch.autograd.grad(
energy_pred.sum(),
coords,
Expand Down
100 changes: 90 additions & 10 deletions descent/tests/targets/test_energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,28 @@ def mock_hoh_entry() -> Entry:
}


def test_create_dataset(mock_meoh_entry):
@pytest.mark.parametrize(
"box_vectors",
[None, torch.eye(3).repeat(2, 1, 1) * 20.0],
ids=["non-periodic", "periodic"],
)
def test_create_dataset(mock_meoh_entry, box_vectors):
entry = {**mock_meoh_entry, "box_vectors": box_vectors}

expected_entries = [
{
"smiles": mock_meoh_entry["smiles"],
"coords": pytest.approx(mock_meoh_entry["coords"].flatten()),
"energy": pytest.approx(mock_meoh_entry["energy"]),
"forces": pytest.approx(mock_meoh_entry["forces"].flatten()),
"id": None,
"smiles": entry["smiles"],
"coords": pytest.approx(entry["coords"].flatten()),
"energy": pytest.approx(entry["energy"]),
"forces": pytest.approx(entry["forces"].flatten()),
"box_vectors": None
if box_vectors is None
else pytest.approx(box_vectors.flatten()),
},
]

dataset = create_dataset([mock_meoh_entry])
dataset = create_dataset([entry])
assert len(dataset) == 1

entries = list(descent.utils.dataset.iter_dataset(dataset))
Expand All @@ -62,11 +73,12 @@ def test_extract_smiles(mock_meoh_entry, mock_hoh_entry):


@pytest.mark.parametrize(
"reference, normalize,"
"box_vectors, reference, normalize,"
"expected_energy_ref, expected_forces_ref, "
"expected_energy_pred, expected_forces_pred",
[
(
pytest.param(
None,
"mean",
True,
torch.tensor([-0.5, 0.5]) / math.sqrt(2.0),
Expand Down Expand Up @@ -95,8 +107,10 @@ def test_extract_smiles(mock_meoh_entry, mock_hoh_entry):
dtype=torch.float64,
)
/ math.sqrt(6.0 * 3.0),
id="non-periodic-mean-normalized",
),
(
pytest.param(
None,
"min",
False,
torch.tensor([0.0, 1.0]),
Expand All @@ -123,10 +137,73 @@ def test_extract_smiles(mock_meoh_entry, mock_hoh_entry):
],
dtype=torch.float64,
),
id="non-periodic-min",
),
pytest.param(
torch.eye(3).repeat(2, 1, 1) * 30.0,
"mean",
True,
torch.tensor([-0.5, 0.5]) / math.sqrt(2.0),
torch.tensor(
[
[0.0, 1.0, 2.0],
[3.0, 4.0, 5.0],
[6.0, 7.0, 8.0],
[9.0, 10.0, 11.0],
[12.0, 13.0, 14.0],
[15.0, 16.0, 17.0],
],
dtype=torch.float64,
)
/ math.sqrt(6.0 * 3.0),
torch.tensor([5.585737228393555, -5.585737705230713]),
torch.tensor(
[
[0.0, -19.695229476617897, 0.0],
[38.04311560258793, 9.847614738308948, 0.0],
[-38.04311560258793, 9.847614738308948, 0.0],
[0.0, 32.3990898002703, 0.0],
[-24.190123962094730, -16.19954490013515, 0.0],
[24.190123962094730, -16.19954490013515, 0.0],
],
dtype=torch.float64,
),
id="periodic-mean-normalized",
),
pytest.param(
torch.eye(3).repeat(2, 1, 1) * 30.0,
"min",
False,
torch.tensor([0.0, 1.0]),
torch.tensor(
[
[0.0, 1.0, 2.0],
[3.0, 4.0, 5.0],
[6.0, 7.0, 8.0],
[9.0, 10.0, 11.0],
[12.0, 13.0, 14.0],
[15.0, 16.0, 17.0],
],
dtype=torch.float64,
),
torch.tensor([0.0, -15.79885196685791]),
torch.tensor(
[
[0.0, -83.55977630615234, 0.0],
[161.40325927734375, 41.77988815307617, 0.0],
[-161.40325927734375, 41.77988815307617, 0.0],
[0.0, 137.4576873779297, 0.0],
[-102.62999725341797, -68.72884368896484, 0.0],
[102.62999725341797, -68.72884368896484, 0.0],
],
dtype=torch.float64,
),
id="periodic-min",
),
],
)
def test_predict(
box_vectors,
reference,
normalize,
expected_energy_ref,
Expand All @@ -135,7 +212,10 @@ def test_predict(
expected_forces_pred,
mock_hoh_entry,
):
dataset = create_dataset([mock_hoh_entry])
entry = {**mock_hoh_entry}
if box_vectors is not None:
entry["box_vectors"] = box_vectors
dataset = create_dataset([entry])

force_field, [topology] = smee.converters.convert_interchange(
openff.interchange.Interchange.from_smirnoff(
Expand Down
Loading