diff --git a/descent/targets/energy.py b/descent/targets/energy.py index b9f89d1..6a53535 100644 --- a/descent/targets/energy.py +++ b/descent/targets/energy.py @@ -11,8 +11,10 @@ DATA_SCHEMA = pyarrow.schema( [ + ("id", pyarrow.string()), ("smiles", pyarrow.string()), ("coords", pyarrow.list_(pyarrow.float64())), + ("box_vectors", pyarrow.list_(pyarrow.float64())), ("energy", pyarrow.list_(pyarrow.float64())), ("forces", pyarrow.list_(pyarrow.float64())), ] @@ -22,6 +24,9 @@ class Entry(typing.TypedDict): """Represents a set of reference energies and forces.""" + id: str | None + """An optional identifier for the entry (e.g. a run name). Defaults to ``None``.""" + smiles: str """The indexed SMILES description of the molecule the energies and forces were computed for.""" @@ -34,6 +39,10 @@ class Entry(typing.TypedDict): forces: torch.Tensor """The reference forces [kcal/mol/Å] with ``shape=(n_confs, n_particles, 3)``.""" + box_vectors: torch.Tensor | None + """The box vectors [Å] for periodic systems with ``shape=(n_confs, 3, 3)``, or + ``None`` for non-periodic systems.""" + def create_dataset(entries: list[Entry]) -> datasets.Dataset: """Create a dataset from a list of existing entries. @@ -48,8 +57,12 @@ def create_dataset(entries: list[Entry]) -> datasets.Dataset: table = pyarrow.Table.from_pylist( [ { + "id": entry.get("id"), "smiles": entry["smiles"], "coords": torch.tensor(entry["coords"]).flatten().tolist(), + "box_vectors": None + if entry.get("box_vectors") is None + else torch.tensor(entry["box_vectors"]).flatten().tolist(), "energy": torch.tensor(entry["energy"]).flatten().tolist(), "forces": torch.tensor(entry["forces"]).flatten().tolist(), } @@ -118,9 +131,27 @@ def predict( coords = ( (coords_flat.reshape(len(energy_ref), -1, 3)).detach().requires_grad_(True) ) + box_vectors = entry.get("box_vectors", None) + topology = topologies[smiles] - energy_pred = smee.compute_energy(topology, force_field, coords) + if box_vectors is not None: + # smee does not support batched periodic evaluations, + # so we loop over conformers. + box_vectors = smee.utils.tensor_like(box_vectors, coords_flat).reshape( + len(energy_ref), 3, 3 + ) + energy_pred = torch.cat( + [ + smee.compute_energy( + topology, force_field, coords[i], box_vectors[i] + ) + for i in range(len(energy_ref)) + ] + ) + else: + energy_pred = smee.compute_energy(topology, force_field, coords, None) + forces_pred = -torch.autograd.grad( energy_pred.sum(), coords, diff --git a/descent/tests/targets/test_energy.py b/descent/tests/targets/test_energy.py index e88860b..b3c0732 100644 --- a/descent/tests/targets/test_energy.py +++ b/descent/tests/targets/test_energy.py @@ -35,17 +35,28 @@ def mock_hoh_entry() -> Entry: } -def test_create_dataset(mock_meoh_entry): +@pytest.mark.parametrize( + "box_vectors", + [None, torch.eye(3).repeat(2, 1, 1) * 20.0], + ids=["non-periodic", "periodic"], +) +def test_create_dataset(mock_meoh_entry, box_vectors): + entry = {**mock_meoh_entry, "box_vectors": box_vectors} + expected_entries = [ { - "smiles": mock_meoh_entry["smiles"], - "coords": pytest.approx(mock_meoh_entry["coords"].flatten()), - "energy": pytest.approx(mock_meoh_entry["energy"]), - "forces": pytest.approx(mock_meoh_entry["forces"].flatten()), + "id": None, + "smiles": entry["smiles"], + "coords": pytest.approx(entry["coords"].flatten()), + "energy": pytest.approx(entry["energy"]), + "forces": pytest.approx(entry["forces"].flatten()), + "box_vectors": None + if box_vectors is None + else pytest.approx(box_vectors.flatten()), }, ] - dataset = create_dataset([mock_meoh_entry]) + dataset = create_dataset([entry]) assert len(dataset) == 1 entries = list(descent.utils.dataset.iter_dataset(dataset)) @@ -62,11 +73,12 @@ def test_extract_smiles(mock_meoh_entry, mock_hoh_entry): @pytest.mark.parametrize( - "reference, normalize," + "box_vectors, reference, normalize," "expected_energy_ref, expected_forces_ref, " "expected_energy_pred, expected_forces_pred", [ - ( + pytest.param( + None, "mean", True, torch.tensor([-0.5, 0.5]) / math.sqrt(2.0), @@ -95,8 +107,10 @@ def test_extract_smiles(mock_meoh_entry, mock_hoh_entry): dtype=torch.float64, ) / math.sqrt(6.0 * 3.0), + id="non-periodic-mean-normalized", ), - ( + pytest.param( + None, "min", False, torch.tensor([0.0, 1.0]), @@ -123,10 +137,73 @@ def test_extract_smiles(mock_meoh_entry, mock_hoh_entry): ], dtype=torch.float64, ), + id="non-periodic-min", + ), + pytest.param( + torch.eye(3).repeat(2, 1, 1) * 30.0, + "mean", + True, + torch.tensor([-0.5, 0.5]) / math.sqrt(2.0), + torch.tensor( + [ + [0.0, 1.0, 2.0], + [3.0, 4.0, 5.0], + [6.0, 7.0, 8.0], + [9.0, 10.0, 11.0], + [12.0, 13.0, 14.0], + [15.0, 16.0, 17.0], + ], + dtype=torch.float64, + ) + / math.sqrt(6.0 * 3.0), + torch.tensor([5.585737228393555, -5.585737705230713]), + torch.tensor( + [ + [0.0, -19.695229476617897, 0.0], + [38.04311560258793, 9.847614738308948, 0.0], + [-38.04311560258793, 9.847614738308948, 0.0], + [0.0, 32.3990898002703, 0.0], + [-24.190123962094730, -16.19954490013515, 0.0], + [24.190123962094730, -16.19954490013515, 0.0], + ], + dtype=torch.float64, + ), + id="periodic-mean-normalized", + ), + pytest.param( + torch.eye(3).repeat(2, 1, 1) * 30.0, + "min", + False, + torch.tensor([0.0, 1.0]), + torch.tensor( + [ + [0.0, 1.0, 2.0], + [3.0, 4.0, 5.0], + [6.0, 7.0, 8.0], + [9.0, 10.0, 11.0], + [12.0, 13.0, 14.0], + [15.0, 16.0, 17.0], + ], + dtype=torch.float64, + ), + torch.tensor([0.0, -15.79885196685791]), + torch.tensor( + [ + [0.0, -83.55977630615234, 0.0], + [161.40325927734375, 41.77988815307617, 0.0], + [-161.40325927734375, 41.77988815307617, 0.0], + [0.0, 137.4576873779297, 0.0], + [-102.62999725341797, -68.72884368896484, 0.0], + [102.62999725341797, -68.72884368896484, 0.0], + ], + dtype=torch.float64, + ), + id="periodic-min", ), ], ) def test_predict( + box_vectors, reference, normalize, expected_energy_ref, @@ -135,7 +212,10 @@ def test_predict( expected_forces_pred, mock_hoh_entry, ): - dataset = create_dataset([mock_hoh_entry]) + entry = {**mock_hoh_entry} + if box_vectors is not None: + entry["box_vectors"] = box_vectors + dataset = create_dataset([entry]) force_field, [topology] = smee.converters.convert_interchange( openff.interchange.Interchange.from_smirnoff(