Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import ase.build
import ase.units
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
import torch
from metatomic.torch.ase_calculator import MetatomicCalculator
from flashmd.ase import EnergyCalculator

from flashmd import get_pretrained
from flashmd.ase.langevin import Langevin
Expand All @@ -44,8 +44,8 @@ atoms.set_velocities( # it is generally a good idea to remove any net velocity
device="cuda" if torch.cuda.is_available() else "cpu"
energy_model, flashmd_model = get_pretrained("pet-omatpes-v2", time_step)

# Set the energy model (see below for more precise usage)
calculator = MetatomicCalculator(energy_model, device=device)
# Set the energy model (optional, see below for more precise usage)
calculator = EnergyCalculator(energy_model, device=device)
atoms.calc = calculator

# Run MD
Expand Down
8 changes: 4 additions & 4 deletions docs/energy.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ running FlashMD, exactly as shown in the opening example (and below with the mor
``do_gradients_with_energy=False`` which will save you memory and computation):

```
from metatomic.torch.ase_calculator import MetatomicCalculator
from flashmd.ase import EnergyCalculator

... # setting up atoms
calculator = MetatomicCalculator(energy_model, device=device, do_gradients_with_energy=False)
calculator = EnergyCalculator(energy_model, device=device, do_gradients_with_energy=False)
atoms.calc = calculator
... # running FlashMD
```
Expand All @@ -31,10 +31,10 @@ with traditional MD. Then, you can just use ASE's MD modules as usual after atta
the energy calculator:

```
from metatomic.torch.ase_calculator import MetatomicCalculator
from flashmd.ase import EnergyCalculator

... # setting up atoms
calculator = MetatomicCalculator(energy_model, device=device)
calculator = EnergyCalculator(energy_model, device=device)
atoms.calc = calculator
... # running MD
```
Expand Down
26 changes: 26 additions & 0 deletions src/flashmd/ase/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import tempfile

from metatomic.torch import AtomisticModel
from metatomic.torch.ase_calculator import MetatomicCalculator


class EnergyCalculator(MetatomicCalculator):
"""
ASE calculator for energy predictions using a metatomic AtomisticModel.

Slightly modified to save the model to a temporary file to ensure compatibility
with ase.io.Trajectory.
"""

def __init__(self, model, *args, **kwargs):
# save the model to a path otherwise it won't work with ase.io.Trajectory
# which calls todict on the calculator

if isinstance(model, AtomisticModel):
with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as f:
path = f.name
model.save(path)
else:
path = model

super().__init__(path, *args, **kwargs)
29 changes: 29 additions & 0 deletions tests/test_energy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import ase.build
import ase.io
import ase.units
import torch
from ase.md import VelocityVerlet

from flashmd import get_pretrained
from flashmd.ase import EnergyCalculator


def test_md(monkeypatch, tmp_path):
"""Test that a short MD run completes without errors with a Trajectory file."""
monkeypatch.chdir(tmp_path)

atoms = ase.build.bulk("Al", "fcc", cubic=True)

time_step = 64
device = "cuda" if torch.cuda.is_available() else "cpu"
energy_model, _ = get_pretrained("pet-omatpes-v2", time_step)
calculator = EnergyCalculator(
energy_model, device=device, do_gradients_with_energy=False
)
atoms.calc = calculator

dyn = VelocityVerlet(atoms=atoms, timestep=time_step * ase.units.fs)
traj = ase.io.Trajectory("test_md.traj", "w", atoms)
dyn.attach(traj.write)
dyn.run(10)
traj.close()