From f8a023dda177c43e9f3a1a455f628b754f984b37 Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Mon, 23 Mar 2026 09:48:50 +0000 Subject: [PATCH 1/6] Add support for periodic systems in energy target. --- descent/targets/energy.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/descent/targets/energy.py b/descent/targets/energy.py index b9f89d1..6c136e6 100644 --- a/descent/targets/energy.py +++ b/descent/targets/energy.py @@ -13,6 +13,7 @@ [ ("smiles", pyarrow.string()), ("coords", pyarrow.list_(pyarrow.float64())), + ("box_vectors", pyarrow.list_(pyarrow.float64())), ("energy", pyarrow.list_(pyarrow.float64())), ("forces", pyarrow.list_(pyarrow.float64())), ] @@ -34,6 +35,10 @@ class Entry(typing.TypedDict): forces: torch.Tensor """The reference forces [kcal/mol/Å] with ``shape=(n_confs, n_particles, 3)``.""" + box_vectors: torch.Tensor | None + """The box vectors [Å] for periodic systems with ``shape=(3, 3)``, or ``None`` + for non-periodic systems.""" + def create_dataset(entries: list[Entry]) -> datasets.Dataset: """Create a dataset from a list of existing entries. @@ -50,6 +55,9 @@ def create_dataset(entries: list[Entry]) -> datasets.Dataset: { "smiles": entry["smiles"], "coords": torch.tensor(entry["coords"]).flatten().tolist(), + "box_vectors": None + if entry.get("box_vectors") is None + else torch.tensor(entry["box_vectors"]).flatten().tolist(), "energy": torch.tensor(entry["energy"]).flatten().tolist(), "forces": torch.tensor(entry["forces"]).flatten().tolist(), } @@ -118,9 +126,14 @@ def predict( coords = ( (coords_flat.reshape(len(energy_ref), -1, 3)).detach().requires_grad_(True) ) + box_vectors = entry.get("box_vectors", None) + + if box_vectors is not None: + box_vectors = smee.utils.tensor_like(box_vectors, coords_flat).reshape(3, 3) + topology = topologies[smiles] - energy_pred = smee.compute_energy(topology, force_field, coords) + energy_pred = smee.compute_energy(topology, force_field, coords, box_vectors) forces_pred = -torch.autograd.grad( energy_pred.sum(), coords, From 0711a416bf2d69425149ea840565f28d1fb9735a Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Mon, 23 Mar 2026 09:49:19 +0000 Subject: [PATCH 2/6] Test energy prediction and dataset creation for periodic systems. --- descent/tests/targets/test_energy.py | 99 +++++++++++++++++++++++++--- 1 file changed, 89 insertions(+), 10 deletions(-) diff --git a/descent/tests/targets/test_energy.py b/descent/tests/targets/test_energy.py index e88860b..cd73945 100644 --- a/descent/tests/targets/test_energy.py +++ b/descent/tests/targets/test_energy.py @@ -35,17 +35,27 @@ def mock_hoh_entry() -> Entry: } -def test_create_dataset(mock_meoh_entry): +@pytest.mark.parametrize( + "box_vectors", + [None, torch.eye(3) * 20.0], + ids=["non-periodic", "periodic"], +) +def test_create_dataset(mock_meoh_entry, box_vectors): + entry = {**mock_meoh_entry, "box_vectors": box_vectors} + expected_entries = [ { - "smiles": mock_meoh_entry["smiles"], - "coords": pytest.approx(mock_meoh_entry["coords"].flatten()), - "energy": pytest.approx(mock_meoh_entry["energy"]), - "forces": pytest.approx(mock_meoh_entry["forces"].flatten()), + "smiles": entry["smiles"], + "coords": pytest.approx(entry["coords"].flatten()), + "energy": pytest.approx(entry["energy"]), + "forces": pytest.approx(entry["forces"].flatten()), + "box_vectors": None + if box_vectors is None + else pytest.approx(box_vectors.flatten()), }, ] - dataset = create_dataset([mock_meoh_entry]) + dataset = create_dataset([entry]) assert len(dataset) == 1 entries = list(descent.utils.dataset.iter_dataset(dataset)) @@ -62,11 +72,12 @@ def test_extract_smiles(mock_meoh_entry, mock_hoh_entry): @pytest.mark.parametrize( - "reference, normalize," + "box_vectors, reference, normalize," "expected_energy_ref, expected_forces_ref, " "expected_energy_pred, expected_forces_pred", [ - ( + pytest.param( + None, "mean", True, torch.tensor([-0.5, 0.5]) / math.sqrt(2.0), @@ -95,8 +106,10 @@ def test_extract_smiles(mock_meoh_entry, mock_hoh_entry): dtype=torch.float64, ) / math.sqrt(6.0 * 3.0), + id="non-periodic-mean-normalized", ), - ( + pytest.param( + None, "min", False, torch.tensor([0.0, 1.0]), @@ -123,10 +136,73 @@ def test_extract_smiles(mock_meoh_entry, mock_hoh_entry): ], dtype=torch.float64, ), + id="non-periodic-min", + ), + pytest.param( + torch.eye(3) * 30.0, + "mean", + True, + torch.tensor([-0.5, 0.5]) / math.sqrt(2.0), + torch.tensor( + [ + [0.0, 1.0, 2.0], + [3.0, 4.0, 5.0], + [6.0, 7.0, 8.0], + [9.0, 10.0, 11.0], + [12.0, 13.0, 14.0], + [15.0, 16.0, 17.0], + ], + dtype=torch.float64, + ) + / math.sqrt(6.0 * 3.0), + torch.tensor([5.585737228393555, -5.585738658905029]), + torch.tensor( + [ + [0.0, -19.695231274883554, 0.0], + [38.04311560258793, 9.847614738308948, 0.0], + [-38.04311560258793, 9.847614738308948, 0.0], + [0.0, 32.39909339680162, 0.0], + [-24.190123962094730, -16.19954490013515, 0.0], + [24.190123962094730, -16.19954490013515, 0.0], + ], + dtype=torch.float64, + ), + id="periodic-mean-normalized", + ), + pytest.param( + torch.eye(3) * 30.0, + "min", + False, + torch.tensor([0.0, 1.0]), + torch.tensor( + [ + [0.0, 1.0, 2.0], + [3.0, 4.0, 5.0], + [6.0, 7.0, 8.0], + [9.0, 10.0, 11.0], + [12.0, 13.0, 14.0], + [15.0, 16.0, 17.0], + ], + dtype=torch.float64, + ), + torch.tensor([0.0, -15.798852920532227]), + torch.tensor( + [ + [0.0, -83.55978393554688, 0.0], + [161.40325927734375, 41.77988815307617, 0.0], + [-161.40325927734375, 41.77988815307617, 0.0], + [0.0, 137.45770263671875, 0.0], + [-102.62999725341797, -68.72884368896484, 0.0], + [102.62999725341797, -68.72884368896484, 0.0], + ], + dtype=torch.float64, + ), + id="periodic-min", ), ], ) def test_predict( + box_vectors, reference, normalize, expected_energy_ref, @@ -135,7 +211,10 @@ def test_predict( expected_forces_pred, mock_hoh_entry, ): - dataset = create_dataset([mock_hoh_entry]) + entry = {**mock_hoh_entry} + if box_vectors is not None: + entry["box_vectors"] = box_vectors + dataset = create_dataset([entry]) force_field, [topology] = smee.converters.convert_interchange( openff.interchange.Interchange.from_smirnoff( From a1b6b65a51cdc5cef7b65a0ad1cf1110dfaf6e66 Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Mon, 23 Mar 2026 10:08:15 +0000 Subject: [PATCH 3/6] Add batch dimension --- descent/targets/energy.py | 22 ++++++++++++++++------ descent/tests/targets/test_energy.py | 18 +++++++++--------- 2 files changed, 25 insertions(+), 15 deletions(-) diff --git a/descent/targets/energy.py b/descent/targets/energy.py index 6c136e6..08dee09 100644 --- a/descent/targets/energy.py +++ b/descent/targets/energy.py @@ -36,8 +36,8 @@ class Entry(typing.TypedDict): """The reference forces [kcal/mol/Å] with ``shape=(n_confs, n_particles, 3)``.""" box_vectors: torch.Tensor | None - """The box vectors [Å] for periodic systems with ``shape=(3, 3)``, or ``None`` - for non-periodic systems.""" + """The box vectors [Å] for periodic systems with ``shape=(n_confs, 3, 3)``, or + ``None`` for non-periodic systems.""" def create_dataset(entries: list[Entry]) -> datasets.Dataset: @@ -128,12 +128,22 @@ def predict( ) box_vectors = entry.get("box_vectors", None) - if box_vectors is not None: - box_vectors = smee.utils.tensor_like(box_vectors, coords_flat).reshape(3, 3) - topology = topologies[smiles] - energy_pred = smee.compute_energy(topology, force_field, coords, box_vectors) + if box_vectors is not None: + # smee does not support batched periodic evaluations, so we loop over conformers. + box_vectors = smee.utils.tensor_like(box_vectors, coords_flat).reshape( + len(energy_ref), 3, 3 + ) + energy_pred = torch.cat( + [ + smee.compute_energy(topology, force_field, coords[i], box_vectors[i]) + for i in range(len(energy_ref)) + ] + ) + else: + energy_pred = smee.compute_energy(topology, force_field, coords, None) + forces_pred = -torch.autograd.grad( energy_pred.sum(), coords, diff --git a/descent/tests/targets/test_energy.py b/descent/tests/targets/test_energy.py index cd73945..8d58c6a 100644 --- a/descent/tests/targets/test_energy.py +++ b/descent/tests/targets/test_energy.py @@ -37,7 +37,7 @@ def mock_hoh_entry() -> Entry: @pytest.mark.parametrize( "box_vectors", - [None, torch.eye(3) * 20.0], + [None, torch.eye(3).repeat(2, 1, 1) * 20.0], ids=["non-periodic", "periodic"], ) def test_create_dataset(mock_meoh_entry, box_vectors): @@ -139,7 +139,7 @@ def test_extract_smiles(mock_meoh_entry, mock_hoh_entry): id="non-periodic-min", ), pytest.param( - torch.eye(3) * 30.0, + torch.eye(3).repeat(2, 1, 1) * 30.0, "mean", True, torch.tensor([-0.5, 0.5]) / math.sqrt(2.0), @@ -155,13 +155,13 @@ def test_extract_smiles(mock_meoh_entry, mock_hoh_entry): dtype=torch.float64, ) / math.sqrt(6.0 * 3.0), - torch.tensor([5.585737228393555, -5.585738658905029]), + torch.tensor([5.585737228393555, -5.585737705230713]), torch.tensor( [ - [0.0, -19.695231274883554, 0.0], + [0.0, -19.695229476617897, 0.0], [38.04311560258793, 9.847614738308948, 0.0], [-38.04311560258793, 9.847614738308948, 0.0], - [0.0, 32.39909339680162, 0.0], + [0.0, 32.3990898002703, 0.0], [-24.190123962094730, -16.19954490013515, 0.0], [24.190123962094730, -16.19954490013515, 0.0], ], @@ -170,7 +170,7 @@ def test_extract_smiles(mock_meoh_entry, mock_hoh_entry): id="periodic-mean-normalized", ), pytest.param( - torch.eye(3) * 30.0, + torch.eye(3).repeat(2, 1, 1) * 30.0, "min", False, torch.tensor([0.0, 1.0]), @@ -185,13 +185,13 @@ def test_extract_smiles(mock_meoh_entry, mock_hoh_entry): ], dtype=torch.float64, ), - torch.tensor([0.0, -15.798852920532227]), + torch.tensor([0.0, -15.79885196685791]), torch.tensor( [ - [0.0, -83.55978393554688, 0.0], + [0.0, -83.55977630615234, 0.0], [161.40325927734375, 41.77988815307617, 0.0], [-161.40325927734375, 41.77988815307617, 0.0], - [0.0, 137.45770263671875, 0.0], + [0.0, 137.4576873779297, 0.0], [-102.62999725341797, -68.72884368896484, 0.0], [102.62999725341797, -68.72884368896484, 0.0], ], From 211b51afcaff04854d302b4b5e8d33f958ba244d Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Tue, 24 Mar 2026 14:45:05 +0000 Subject: [PATCH 4/6] Add optional identifier to dataset entries. --- descent/targets/energy.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/descent/targets/energy.py b/descent/targets/energy.py index 08dee09..62cab10 100644 --- a/descent/targets/energy.py +++ b/descent/targets/energy.py @@ -11,6 +11,7 @@ DATA_SCHEMA = pyarrow.schema( [ + ("id", pyarrow.string()), ("smiles", pyarrow.string()), ("coords", pyarrow.list_(pyarrow.float64())), ("box_vectors", pyarrow.list_(pyarrow.float64())), @@ -23,6 +24,9 @@ class Entry(typing.TypedDict): """Represents a set of reference energies and forces.""" + id: str | None + """An optional identifier for the entry (e.g. a run name). Defaults to ``None``.""" + smiles: str """The indexed SMILES description of the molecule the energies and forces were computed for.""" @@ -53,6 +57,7 @@ def create_dataset(entries: list[Entry]) -> datasets.Dataset: table = pyarrow.Table.from_pylist( [ { + "id": entry.get("id"), "smiles": entry["smiles"], "coords": torch.tensor(entry["coords"]).flatten().tolist(), "box_vectors": None From 0da9ecd56fbc89f1aec2e0b583f662059f2ada78 Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Tue, 24 Mar 2026 17:18:33 +0000 Subject: [PATCH 5/6] Add optional "id" field to dataset entries in test_create_dataset --- descent/tests/targets/test_energy.py | 1 + 1 file changed, 1 insertion(+) diff --git a/descent/tests/targets/test_energy.py b/descent/tests/targets/test_energy.py index 8d58c6a..b3c0732 100644 --- a/descent/tests/targets/test_energy.py +++ b/descent/tests/targets/test_energy.py @@ -45,6 +45,7 @@ def test_create_dataset(mock_meoh_entry, box_vectors): expected_entries = [ { + "id": None, "smiles": entry["smiles"], "coords": pytest.approx(entry["coords"].flatten()), "energy": pytest.approx(entry["energy"]), From 3afeb7b681a70745e7db4d96087e6e2c9bad99b2 Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Wed, 25 Mar 2026 11:59:28 +0000 Subject: [PATCH 6/6] Ruff formatting. --- descent/targets/energy.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/descent/targets/energy.py b/descent/targets/energy.py index 62cab10..6a53535 100644 --- a/descent/targets/energy.py +++ b/descent/targets/energy.py @@ -136,13 +136,16 @@ def predict( topology = topologies[smiles] if box_vectors is not None: - # smee does not support batched periodic evaluations, so we loop over conformers. + # smee does not support batched periodic evaluations, + # so we loop over conformers. box_vectors = smee.utils.tensor_like(box_vectors, coords_flat).reshape( len(energy_ref), 3, 3 ) energy_pred = torch.cat( [ - smee.compute_energy(topology, force_field, coords[i], box_vectors[i]) + smee.compute_energy( + topology, force_field, coords[i], box_vectors[i] + ) for i in range(len(energy_ref)) ] )