diff --git a/mace/calculators/__init__.py b/mace/calculators/__init__.py index cb2a16f88..3315275ad 100644 --- a/mace/calculators/__init__.py +++ b/mace/calculators/__init__.py @@ -1,9 +1,11 @@ from .foundations_models import mace_anicc, mace_mp, mace_off, mace_omol, mace_polar from .lammps_mace import LAMMPS_MACE from .mace import MACECalculator +from .mace_torchsim import MaceTorchSimModel __all__ = [ "MACECalculator", + "MaceTorchSimModel", "LAMMPS_MACE", "mace_mp", "mace_off", diff --git a/tests/test_torchsim.py b/tests/test_torchsim.py index 68ec7b752..9ef0d9548 100644 --- a/tests/test_torchsim.py +++ b/tests/test_torchsim.py @@ -8,14 +8,19 @@ import numpy as np import pytest import torch -from ase import build + +from mace.calculators import mace_mp, mace_off try: import torch_sim as ts - - TORCHSIM_AVAILABLE = True -except ImportError: - TORCHSIM_AVAILABLE = False + from torch_sim.models.interface import validate_model_outputs + from torch_sim.testing import ( + SIMSTATE_BULK_GENERATORS, + SIMSTATE_MOLECULE_GENERATORS, + assert_model_calculator_consistency, + ) +except (ImportError, ModuleNotFoundError): + pytest.skip("Skipping torch-sim tests due to ImportError", allow_module_level=True) try: import cuequivariance as cue # noqa: F401 @@ -24,13 +29,95 @@ except ImportError: CUET_AVAILABLE = False -pytestmark = pytest.mark.skipif( - not TORCHSIM_AVAILABLE, reason="torch-sim not installed" -) +from mace.calculators.mace_torchsim import MaceTorchSimModel pytest_mace_dir = Path(__file__).parent.parent run_train = Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" +DEVICE = torch.device("cpu") +DTYPE = torch.float64 +MACE_MP_MODEL = "small-0b" +MACE_OFF_MODEL = "small" + + +def _to_dtype_name(dtype: torch.dtype) -> str: + if dtype == torch.float32: + return "float32" + if dtype == torch.float64: + return "float64" + raise ValueError(f"Unsupported dtype {dtype}") + + +# --------------------------------------------------------------------------- +# Fixtures for foundation-model tests +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def raw_mace_mp_model(): + return mace_mp( + model=MACE_MP_MODEL, + device=str(DEVICE), + default_dtype=_to_dtype_name(DTYPE), + return_raw_model=True, + ) + + +@pytest.fixture(scope="module") +def raw_mace_off_model(): + return mace_off( + model=MACE_OFF_MODEL, + device=str(DEVICE), + default_dtype=_to_dtype_name(DTYPE), + return_raw_model=True, + ) + + +@pytest.fixture +def ase_mace_mp_calculator(): + return mace_mp( + model=MACE_MP_MODEL, + device=str(DEVICE), + default_dtype=_to_dtype_name(DTYPE), + dispersion=False, + ) + + +@pytest.fixture +def ase_mace_off_calculator(): + return mace_off( + model=MACE_OFF_MODEL, + device=str(DEVICE), + default_dtype=_to_dtype_name(DTYPE), + ) + + +@pytest.fixture +def ts_mace_mp_model(raw_mace_mp_model): + return MaceTorchSimModel( + model=raw_mace_mp_model, + device=DEVICE, + dtype=DTYPE, + compute_forces=True, + compute_stress=True, + ) + + +@pytest.fixture +def ts_mace_off_model(raw_mace_off_model): + return MaceTorchSimModel( + model=raw_mace_off_model, + device=DEVICE, + dtype=DTYPE, + compute_forces=True, + compute_stress=False, + ) + + +# --------------------------------------------------------------------------- +# Fixtures for locally-trained-model tests +# --------------------------------------------------------------------------- + @pytest.fixture(scope="module") def trained_model_path(tmp_path_factory): @@ -123,6 +210,71 @@ def water_atoms(): return atoms +# --------------------------------------------------------------------------- +# Foundation-model tests +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("sim_state_name", ("si_sim_state", "rattled_si_sim_state")) +def test_torch_sim_mace_mp_consistency( + sim_state_name, ts_mace_mp_model, ase_mace_mp_calculator +): + sim_state = SIMSTATE_BULK_GENERATORS[sim_state_name](DEVICE, DTYPE) + assert_model_calculator_consistency( + model=ts_mace_mp_model, + calculator=ase_mace_mp_calculator, + sim_state=sim_state, + ) + + +@pytest.mark.parametrize("sim_state_name", ("benzene_sim_state",)) +def test_torch_sim_mace_off_consistency( + sim_state_name, ts_mace_off_model, ase_mace_off_calculator +): + sim_state = SIMSTATE_MOLECULE_GENERATORS[sim_state_name](DEVICE, DTYPE) + assert_model_calculator_consistency( + model=ts_mace_off_model, + calculator=ase_mace_off_calculator, + sim_state=sim_state, + ) + + +@pytest.mark.parametrize("dtype", (torch.float32, torch.float64)) +def test_torch_sim_mace_dtype_smoke(raw_mace_mp_model, dtype: torch.dtype): + model = MaceTorchSimModel( + model=raw_mace_mp_model, + device=DEVICE, + dtype=dtype, + compute_forces=True, + compute_stress=True, + ) + state = SIMSTATE_BULK_GENERATORS["si_sim_state"](DEVICE, dtype) + output = model(state) + + assert output["energy"].shape == (1,) + assert torch.is_floating_point(output["energy"]) + assert output["forces"].shape == state.positions.shape + assert torch.is_floating_point(output["forces"]) + assert output["stress"].shape == (1, 3, 3) + + +def test_torch_sim_mace_off_output_keys(ts_mace_off_model): + state = SIMSTATE_MOLECULE_GENERATORS["benzene_sim_state"](DEVICE, DTYPE) + output = ts_mace_off_model(state) + assert "energy" in output + assert "forces" in output + assert "stress" not in output + + +def test_torch_sim_mace_validate_outputs(ts_mace_mp_model): + validate_model_outputs(ts_mace_mp_model, DEVICE, DTYPE) + + +# --------------------------------------------------------------------------- +# Locally-trained-model tests +# --------------------------------------------------------------------------- + + def test_torchsim_basic(trained_model_path, water_atoms): from mace.calculators.mace_torchsim import MaceTorchSimModel