Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mace/calculators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from .foundations_models import mace_anicc, mace_mp, mace_off, mace_omol, mace_polar
from .lammps_mace import LAMMPS_MACE
from .mace import MACECalculator
from .mace_torchsim import MaceTorchSimModel

__all__ = [
"MACECalculator",
"MaceTorchSimModel",
"LAMMPS_MACE",
"mace_mp",
"mace_off",
Expand Down
168 changes: 160 additions & 8 deletions tests/test_torchsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,19 @@
import numpy as np
import pytest
import torch
from ase import build

from mace.calculators import mace_mp, mace_off

try:
import torch_sim as ts

TORCHSIM_AVAILABLE = True
except ImportError:
TORCHSIM_AVAILABLE = False
from torch_sim.models.interface import validate_model_outputs
from torch_sim.testing import (
SIMSTATE_BULK_GENERATORS,
SIMSTATE_MOLECULE_GENERATORS,
assert_model_calculator_consistency,
)
except (ImportError, ModuleNotFoundError):
pytest.skip("Skipping torch-sim tests due to ImportError", allow_module_level=True)

try:
import cuequivariance as cue # noqa: F401
Expand All @@ -24,13 +29,95 @@
except ImportError:
CUET_AVAILABLE = False

pytestmark = pytest.mark.skipif(
not TORCHSIM_AVAILABLE, reason="torch-sim not installed"
)
from mace.calculators.mace_torchsim import MaceTorchSimModel

pytest_mace_dir = Path(__file__).parent.parent
run_train = Path(__file__).parent.parent / "mace" / "cli" / "run_train.py"

DEVICE = torch.device("cpu")
DTYPE = torch.float64
MACE_MP_MODEL = "small-0b"
MACE_OFF_MODEL = "small"


def _to_dtype_name(dtype: torch.dtype) -> str:
if dtype == torch.float32:
return "float32"
if dtype == torch.float64:
return "float64"
raise ValueError(f"Unsupported dtype {dtype}")


# ---------------------------------------------------------------------------
# Fixtures for foundation-model tests
# ---------------------------------------------------------------------------


@pytest.fixture(scope="module")
def raw_mace_mp_model():
return mace_mp(
model=MACE_MP_MODEL,
device=str(DEVICE),
default_dtype=_to_dtype_name(DTYPE),
return_raw_model=True,
)


@pytest.fixture(scope="module")
def raw_mace_off_model():
return mace_off(
model=MACE_OFF_MODEL,
device=str(DEVICE),
default_dtype=_to_dtype_name(DTYPE),
return_raw_model=True,
)


@pytest.fixture
def ase_mace_mp_calculator():
return mace_mp(
model=MACE_MP_MODEL,
device=str(DEVICE),
default_dtype=_to_dtype_name(DTYPE),
dispersion=False,
)


@pytest.fixture
def ase_mace_off_calculator():
return mace_off(
model=MACE_OFF_MODEL,
device=str(DEVICE),
default_dtype=_to_dtype_name(DTYPE),
)


@pytest.fixture
def ts_mace_mp_model(raw_mace_mp_model):
return MaceTorchSimModel(
model=raw_mace_mp_model,
device=DEVICE,
dtype=DTYPE,
compute_forces=True,
compute_stress=True,
)


@pytest.fixture
def ts_mace_off_model(raw_mace_off_model):
return MaceTorchSimModel(
model=raw_mace_off_model,
device=DEVICE,
dtype=DTYPE,
compute_forces=True,
compute_stress=False,
)


# ---------------------------------------------------------------------------
# Fixtures for locally-trained-model tests
# ---------------------------------------------------------------------------


@pytest.fixture(scope="module")
def trained_model_path(tmp_path_factory):
Expand Down Expand Up @@ -123,6 +210,71 @@ def water_atoms():
return atoms


# ---------------------------------------------------------------------------
# Foundation-model tests
# ---------------------------------------------------------------------------


@pytest.mark.parametrize("sim_state_name", ("si_sim_state", "rattled_si_sim_state"))
def test_torch_sim_mace_mp_consistency(
sim_state_name, ts_mace_mp_model, ase_mace_mp_calculator
):
sim_state = SIMSTATE_BULK_GENERATORS[sim_state_name](DEVICE, DTYPE)
assert_model_calculator_consistency(
model=ts_mace_mp_model,
calculator=ase_mace_mp_calculator,
sim_state=sim_state,
)


@pytest.mark.parametrize("sim_state_name", ("benzene_sim_state",))
def test_torch_sim_mace_off_consistency(
sim_state_name, ts_mace_off_model, ase_mace_off_calculator
):
sim_state = SIMSTATE_MOLECULE_GENERATORS[sim_state_name](DEVICE, DTYPE)
assert_model_calculator_consistency(
model=ts_mace_off_model,
calculator=ase_mace_off_calculator,
sim_state=sim_state,
)


@pytest.mark.parametrize("dtype", (torch.float32, torch.float64))
def test_torch_sim_mace_dtype_smoke(raw_mace_mp_model, dtype: torch.dtype):
model = MaceTorchSimModel(
model=raw_mace_mp_model,
device=DEVICE,
dtype=dtype,
compute_forces=True,
compute_stress=True,
)
state = SIMSTATE_BULK_GENERATORS["si_sim_state"](DEVICE, dtype)
output = model(state)

assert output["energy"].shape == (1,)
assert torch.is_floating_point(output["energy"])
assert output["forces"].shape == state.positions.shape
assert torch.is_floating_point(output["forces"])
assert output["stress"].shape == (1, 3, 3)


def test_torch_sim_mace_off_output_keys(ts_mace_off_model):
state = SIMSTATE_MOLECULE_GENERATORS["benzene_sim_state"](DEVICE, DTYPE)
output = ts_mace_off_model(state)
assert "energy" in output
assert "forces" in output
assert "stress" not in output


def test_torch_sim_mace_validate_outputs(ts_mace_mp_model):
validate_model_outputs(ts_mace_mp_model, DEVICE, DTYPE)


# ---------------------------------------------------------------------------
# Locally-trained-model tests
# ---------------------------------------------------------------------------


def test_torchsim_basic(trained_model_path, water_atoms):
from mace.calculators.mace_torchsim import MaceTorchSimModel

Expand Down
Loading