From 2cc5b3281a8a846bbc3f50bf76619466dd05410d Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Sun, 11 Jan 2026 18:59:07 +0100 Subject: [PATCH] Fix `ase.io.Trajectory` compatibility --- README.md | 6 +++--- docs/energy.md | 8 ++++---- src/flashmd/ase/__init__.py | 26 ++++++++++++++++++++++++++ tests/test_energy.py | 29 +++++++++++++++++++++++++++++ 4 files changed, 62 insertions(+), 7 deletions(-) create mode 100644 tests/test_energy.py diff --git a/README.md b/README.md index 0ed38a0..6a13537 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ import ase.build import ase.units from ase.md.velocitydistribution import MaxwellBoltzmannDistribution import torch -from metatomic.torch.ase_calculator import MetatomicCalculator +from flashmd.ase import EnergyCalculator from flashmd import get_pretrained from flashmd.ase.langevin import Langevin @@ -44,8 +44,8 @@ atoms.set_velocities( # it is generally a good idea to remove any net velocity device="cuda" if torch.cuda.is_available() else "cpu" energy_model, flashmd_model = get_pretrained("pet-omatpes-v2", time_step) -# Set the energy model (see below for more precise usage) -calculator = MetatomicCalculator(energy_model, device=device) +# Set the energy model (optional, see below for more precise usage) +calculator = EnergyCalculator(energy_model, device=device) atoms.calc = calculator # Run MD diff --git a/docs/energy.md b/docs/energy.md index 672d3f6..8081cc1 100644 --- a/docs/energy.md +++ b/docs/energy.md @@ -14,10 +14,10 @@ running FlashMD, exactly as shown in the opening example (and below with the mor ``do_gradients_with_energy=False`` which will save you memory and computation): ``` -from metatomic.torch.ase_calculator import MetatomicCalculator +from flashmd.ase import EnergyCalculator ... # setting up atoms -calculator = MetatomicCalculator(energy_model, device=device, do_gradients_with_energy=False) +calculator = EnergyCalculator(energy_model, device=device, do_gradients_with_energy=False) atoms.calc = calculator ... # running FlashMD ``` @@ -31,10 +31,10 @@ with traditional MD. Then, you can just use ASE's MD modules as usual after atta the energy calculator: ``` -from metatomic.torch.ase_calculator import MetatomicCalculator +from flashmd.ase import EnergyCalculator ... # setting up atoms -calculator = MetatomicCalculator(energy_model, device=device) +calculator = EnergyCalculator(energy_model, device=device) atoms.calc = calculator ... # running MD ``` diff --git a/src/flashmd/ase/__init__.py b/src/flashmd/ase/__init__.py index e69de29..812b19f 100644 --- a/src/flashmd/ase/__init__.py +++ b/src/flashmd/ase/__init__.py @@ -0,0 +1,26 @@ +import tempfile + +from metatomic.torch import AtomisticModel +from metatomic.torch.ase_calculator import MetatomicCalculator + + +class EnergyCalculator(MetatomicCalculator): + """ + ASE calculator for energy predictions using a metatomic AtomisticModel. + + Slightly modified to save the model to a temporary file to ensure compatibility + with ase.io.Trajectory. + """ + + def __init__(self, model, *args, **kwargs): + # save the model to a path otherwise it won't work with ase.io.Trajectory + # which calls todict on the calculator + + if isinstance(model, AtomisticModel): + with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as f: + path = f.name + model.save(path) + else: + path = model + + super().__init__(path, *args, **kwargs) diff --git a/tests/test_energy.py b/tests/test_energy.py new file mode 100644 index 0000000..0c5cb8d --- /dev/null +++ b/tests/test_energy.py @@ -0,0 +1,29 @@ +import ase.build +import ase.io +import ase.units +import torch +from ase.md import VelocityVerlet + +from flashmd import get_pretrained +from flashmd.ase import EnergyCalculator + + +def test_md(monkeypatch, tmp_path): + """Test that a short MD run completes without errors with a Trajectory file.""" + monkeypatch.chdir(tmp_path) + + atoms = ase.build.bulk("Al", "fcc", cubic=True) + + time_step = 64 + device = "cuda" if torch.cuda.is_available() else "cpu" + energy_model, _ = get_pretrained("pet-omatpes-v2", time_step) + calculator = EnergyCalculator( + energy_model, device=device, do_gradients_with_energy=False + ) + atoms.calc = calculator + + dyn = VelocityVerlet(atoms=atoms, timestep=time_step * ase.units.fs) + traj = ase.io.Trajectory("test_md.traj", "w", atoms) + dyn.attach(traj.write) + dyn.run(10) + traj.close()