From 21c1871d9a7d4ee7805275d6b52b62a32516f87c Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Fri, 27 Feb 2026 16:39:55 -0800 Subject: [PATCH 1/2] Add first party support for MACE TorchSim Model Interface --- .github/workflows/unittest.yaml | 22 +++ mace/calculators/__init__.py | 2 + mace/calculators/torch_sim.py | 255 ++++++++++++++++++++++++++++++++ setup.cfg | 2 + tests/test_torch_sim.py | 147 ++++++++++++++++++ 5 files changed, 428 insertions(+) create mode 100644 mace/calculators/torch_sim.py create mode 100644 tests/test_torch_sim.py diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml index 3c87c9b21..362686259 100644 --- a/.github/workflows/unittest.yaml +++ b/.github/workflows/unittest.yaml @@ -106,3 +106,25 @@ jobs: run: | pytest -v tests/test_cueq_oeq.py -k TestCueq pytest -v tests/test_calculator.py + + pytest-torchsim: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: "pip" + + - name: Install requirements alongside torchsim + run: | + pip install -U pip + pip install ".[dev,torchsim]" + + - name: Log installed environment alongside torchsim + run: | + python3 -m pip freeze + + - name: Run torchsim tests + run: | + pytest -v tests/test_torch_sim.py diff --git a/mace/calculators/__init__.py b/mace/calculators/__init__.py index d60523766..69bc0bcc6 100644 --- a/mace/calculators/__init__.py +++ b/mace/calculators/__init__.py @@ -1,9 +1,11 @@ from .foundations_models import mace_anicc, mace_mp, mace_off, mace_omol from .lammps_mace import LAMMPS_MACE from .mace import MACECalculator +from .torch_sim import MaceTorchSimModel __all__ = [ "MACECalculator", + "MaceTorchSimModel", "LAMMPS_MACE", "mace_mp", "mace_off", diff --git a/mace/calculators/torch_sim.py b/mace/calculators/torch_sim.py new file mode 100644 index 000000000..8ff0fc657 --- /dev/null +++ b/mace/calculators/torch_sim.py @@ -0,0 +1,255 @@ +from pathlib import Path +from typing import Any, Callable, Dict, Optional, Union + +import torch + +from mace.tools import atomic_numbers_to_indices, utils + +try: + import torch_sim as ts + from torch_sim.models.interface import ModelInterface + from torch_sim.neighbors import torchsim_nl + + _TORCHSIM_IMPORT_ERROR: Optional[ImportError] = None +except ImportError as exc: + ts = None # type: ignore[assignment] + torchsim_nl = None # type: ignore[assignment] + _TORCHSIM_IMPORT_ERROR = exc + + class ModelInterface(torch.nn.Module): # type: ignore[no-redef] # pylint: disable=abstract-method + """Fallback base class when torch-sim is not installed.""" + + +def to_one_hot( + indices: torch.Tensor, num_classes: int, dtype: torch.dtype +) -> torch.Tensor: + """Generate one-hot vectors from class indices.""" + shape = indices.shape[:-1] + (num_classes,) + out = torch.zeros(shape, device=indices.device, dtype=dtype).view(shape) + out.scatter_(dim=-1, index=indices, value=1) + return out.view(*shape) + + +class MaceTorchSimModel(ModelInterface): + """TorchSim wrapper around a MACE model.""" + + def __init__( + self, + model: Union[str, Path, torch.nn.Module], + device: Optional[torch.device] = None, + dtype: torch.dtype = torch.float64, + neighbor_list_fn: Optional[Callable] = None, + compute_forces: bool = True, + compute_stress: bool = True, + enable_cueq: bool = False, + atomic_numbers: Optional[torch.Tensor] = None, + system_idx: Optional[torch.Tensor] = None, + ) -> None: + if _TORCHSIM_IMPORT_ERROR is not None: + raise ImportError( + "MaceTorchSimModel requires torch-sim-atomistic. " + "Install with `pip install torch-sim-atomistic` " + "or `pip install -e '.[torchsim]'`." + ) from _TORCHSIM_IMPORT_ERROR + + super().__init__() + self._device = device or torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) + self._dtype = dtype + self._compute_forces = compute_forces + self._compute_stress = compute_stress + self._memory_scales_with = "n_atoms_x_density" + self.neighbor_list_fn = neighbor_list_fn or torchsim_nl + + if isinstance(model, (str, Path)): + self.model = torch.load(str(model), map_location=self.device) + elif isinstance(model, torch.nn.Module): + self.model = model.to(self.device) + else: + raise TypeError("model must be a path or torch.nn.Module") + + self.model = self.model.eval().to(device=self._device) + if self.dtype is not None: + self.model = self.model.to(dtype=self.dtype) + + if enable_cueq: + try: + from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq + except ImportError as exc: + raise ImportError( + "cuequivariance is not installed so CuEq acceleration cannot be used" + ) from exc + self.model = run_e3nn_to_cueq(self.model, device=self.device.type) + + self.r_max = self.model.r_max + self.z_table = utils.AtomicNumberTable( + [int(z) for z in self.model.atomic_numbers] + ) + self.model.atomic_numbers = ( + self.model.atomic_numbers.detach().clone().to(device=self.device) + ) + + self.atomic_numbers_in_init = atomic_numbers is not None + if atomic_numbers is not None: + if system_idx is None: + system_idx = torch.zeros( + len(atomic_numbers), dtype=torch.long, device=self.device + ) + self.setup_from_system_idx(atomic_numbers, system_idx) + + def setup_from_system_idx( + self, atomic_numbers: torch.Tensor, system_idx: torch.Tensor + ) -> None: + """Prepare cached batch tensors from atom-wise system assignment.""" + if atomic_numbers.shape[0] != system_idx.shape[0]: + raise ValueError("atomic_numbers and system_idx must have same shape[0]") + + self.atomic_numbers = atomic_numbers.to(device=self.device, dtype=torch.long) + self.system_idx = system_idx.to(device=self.device, dtype=torch.long) + + if self.system_idx.numel() == 0: + raise ValueError("at least one atom is required") + + self.n_systems = int(self.system_idx.max().item()) + 1 + self.n_atoms_per_system = [] + ptr = [0] + for idx in range(self.n_systems): + n_atoms = int((self.system_idx == idx).sum().item()) + self.n_atoms_per_system.append(n_atoms) + ptr.append(ptr[-1] + n_atoms) + + self.ptr = torch.tensor(ptr, dtype=torch.long, device=self.device) + self.total_atoms = int(self.atomic_numbers.shape[0]) + + atomic_indices = torch.tensor( + atomic_numbers_to_indices( + self.atomic_numbers.detach().cpu().numpy(), z_table=self.z_table + ), + dtype=torch.long, + device=self.device, + ).unsqueeze(-1) + + self.node_attrs = to_one_hot( + atomic_indices, + num_classes=len(self.z_table), + dtype=self.dtype, + ) + + def forward(self, state: Any) -> Dict[str, torch.Tensor]: + """Compute energies, forces and stresses for one or more systems.""" + if ts is None: + raise RuntimeError( + "torch-sim is required to call MaceTorchSimModel.forward" + ) + + if isinstance(state, ts.SimState): + sim_state = state.clone() + else: + state_dict = dict(state) + if "masses" not in state_dict: + state_dict["masses"] = torch.ones_like(state_dict["positions"]) + sim_state = ts.SimState(**state_dict) + + if sim_state.device != self.device or sim_state.dtype != self.dtype: + sim_state = sim_state.to(self.device, self.dtype) + + state_atomic_numbers = getattr(sim_state, "atomic_numbers", None) + if state_atomic_numbers is None and not self.atomic_numbers_in_init: + raise ValueError( + "atomic_numbers must be provided in the constructor or in forward." + ) + + if state_atomic_numbers is not None and self.atomic_numbers_in_init: + if not torch.equal(state_atomic_numbers, self.atomic_numbers): + raise ValueError( + "atomic_numbers in state do not match constructor values." + ) + + if sim_state.system_idx is None: + if not hasattr(self, "system_idx"): + raise ValueError( + "system_idx must be provided if not set during initialization" + ) + sim_state.system_idx = self.system_idx + + if not self.atomic_numbers_in_init: + cached_atomic_numbers = getattr(self, "atomic_numbers", None) + cached_system_idx = getattr(self, "system_idx", None) + needs_setup = state_atomic_numbers is not None and ( + cached_atomic_numbers is None + or cached_system_idx is None + or not torch.equal(state_atomic_numbers, cached_atomic_numbers) + or not torch.equal(sim_state.system_idx, cached_system_idx) + ) + if needs_setup: + self.setup_from_system_idx(state_atomic_numbers, sim_state.system_idx) + + wrapped_positions = ( + ts.transforms.pbc_wrap_batched( + sim_state.positions, + sim_state.cell, + sim_state.system_idx, + sim_state.pbc, + ) + if sim_state.pbc.any() + else sim_state.positions + ) + + edge_index, mapping_system, unit_shifts = self.neighbor_list_fn( + wrapped_positions, + sim_state.row_vector_cell, + sim_state.pbc, + self.r_max, + sim_state.system_idx, + ) + shifts = ts.transforms.compute_cell_shifts( + sim_state.row_vector_cell, unit_shifts, mapping_system + ) + + data_dict = { + "ptr": self.ptr, + "node_attrs": self.node_attrs, + "batch": sim_state.system_idx, + "pbc": sim_state.pbc, + "cell": sim_state.row_vector_cell, + "positions": wrapped_positions, + "edge_index": edge_index, + "unit_shifts": unit_shifts, + "shifts": shifts, + "total_charge": sim_state.charge, + "total_spin": sim_state.spin, + } + + out = self.model( + data_dict, + compute_force=self.compute_forces, + compute_stress=self.compute_stress, + ) + + n_systems = sim_state.n_systems + results: Dict[str, torch.Tensor] = {} + + energy = out.get("energy") + if energy is None: + results["energy"] = torch.zeros( + n_systems, device=self.device, dtype=self.dtype + ) + else: + results["energy"] = energy.detach() + + if self.compute_forces: + forces = out.get("forces") + if forces is None: + forces = torch.zeros_like(sim_state.positions) + results["forces"] = forces.detach() + + if self.compute_stress: + stress = out.get("stress") + if stress is None: + stress = torch.zeros( + n_systems, 3, 3, device=self.device, dtype=self.dtype + ) + results["stress"] = stress.detach() + + return results diff --git a/setup.cfg b/setup.cfg index 16dd8e9d8..4fa35a48c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -61,6 +61,8 @@ dev = pytest-benchmark pylint schedulefree = schedulefree +torchsim = + torch-sim-atomistic; python_version >= "3.12" cueq = cuequivariance-torch>=0.2.0 cueq-cuda-11 = cuequivariance-ops-torch-cu11>=0.2.0 cueq-cuda-12 = cuequivariance-ops-torch-cu12>=0.2.0 diff --git a/tests/test_torch_sim.py b/tests/test_torch_sim.py new file mode 100644 index 000000000..73be74c4d --- /dev/null +++ b/tests/test_torch_sim.py @@ -0,0 +1,147 @@ +import pytest +import torch + +from mace.calculators import mace_mp, mace_off + +try: + import torch_sim as ts + from torch_sim.models.interface import validate_model_outputs + from torch_sim.testing import ( + SIMSTATE_BULK_GENERATORS, + SIMSTATE_MOLECULE_GENERATORS, + assert_model_calculator_consistency, + ) +except (ImportError, ModuleNotFoundError): + pytest.skip("Skipping torch-sim tests due to ImportError", allow_module_level=True) + +from mace.calculators.torch_sim import MaceTorchSimModel + + +DEVICE = torch.device("cpu") +DTYPE = torch.float64 +MACE_MP_MODEL = "small-0b" +MACE_OFF_MODEL = "small" + + +def _to_dtype_name(dtype: torch.dtype) -> str: + if dtype == torch.float32: + return "float32" + if dtype == torch.float64: + return "float64" + raise ValueError(f"Unsupported dtype {dtype}") + + +@pytest.fixture(scope="module") +def raw_mace_mp_model(): + return mace_mp( + model=MACE_MP_MODEL, + device=str(DEVICE), + default_dtype=_to_dtype_name(DTYPE), + return_raw_model=True, + ) + + +@pytest.fixture(scope="module") +def raw_mace_off_model(): + return mace_off( + model=MACE_OFF_MODEL, + device=str(DEVICE), + default_dtype=_to_dtype_name(DTYPE), + return_raw_model=True, + ) + + +@pytest.fixture +def ase_mace_mp_calculator(): + return mace_mp( + model=MACE_MP_MODEL, + device=str(DEVICE), + default_dtype=_to_dtype_name(DTYPE), + dispersion=False, + ) + + +@pytest.fixture +def ase_mace_off_calculator(): + return mace_off( + model=MACE_OFF_MODEL, + device=str(DEVICE), + default_dtype=_to_dtype_name(DTYPE), + ) + + +@pytest.fixture +def ts_mace_mp_model(raw_mace_mp_model): + return MaceTorchSimModel( + model=raw_mace_mp_model, + device=DEVICE, + dtype=DTYPE, + compute_forces=True, + compute_stress=True, + ) + + +@pytest.fixture +def ts_mace_off_model(raw_mace_off_model): + return MaceTorchSimModel( + model=raw_mace_off_model, + device=DEVICE, + dtype=DTYPE, + compute_forces=True, + compute_stress=False, + ) + + +@pytest.mark.parametrize("sim_state_name", ("si_sim_state", "rattled_si_sim_state")) +def test_torch_sim_mace_mp_consistency( + sim_state_name, ts_mace_mp_model, ase_mace_mp_calculator +): + sim_state = SIMSTATE_BULK_GENERATORS[sim_state_name](DEVICE, DTYPE) + assert_model_calculator_consistency( + model=ts_mace_mp_model, + calculator=ase_mace_mp_calculator, + sim_state=sim_state, + ) + + +@pytest.mark.parametrize("sim_state_name", ("benzene_sim_state",)) +def test_torch_sim_mace_off_consistency( + sim_state_name, ts_mace_off_model, ase_mace_off_calculator +): + sim_state = SIMSTATE_MOLECULE_GENERATORS[sim_state_name](DEVICE, DTYPE) + assert_model_calculator_consistency( + model=ts_mace_off_model, + calculator=ase_mace_off_calculator, + sim_state=sim_state, + ) + + +@pytest.mark.parametrize("dtype", (torch.float32, torch.float64)) +def test_torch_sim_mace_dtype_smoke(raw_mace_mp_model, dtype: torch.dtype): + model = MaceTorchSimModel( + model=raw_mace_mp_model, + device=DEVICE, + dtype=dtype, + compute_forces=True, + compute_stress=True, + ) + state = SIMSTATE_BULK_GENERATORS["si_sim_state"](DEVICE, dtype) + output = model(state) + + assert output["energy"].shape == (1,) + assert torch.is_floating_point(output["energy"]) + assert output["forces"].shape == state.positions.shape + assert torch.is_floating_point(output["forces"]) + assert output["stress"].shape == (1, 3, 3) + + +def test_torch_sim_mace_off_output_keys(ts_mace_off_model): + state = SIMSTATE_MOLECULE_GENERATORS["benzene_sim_state"](DEVICE, DTYPE) + output = ts_mace_off_model(state) + assert "energy" in output + assert "forces" in output + assert "stress" not in output + + +def test_torch_sim_mace_validate_outputs(ts_mace_mp_model): + validate_model_outputs(ts_mace_mp_model, DEVICE, DTYPE) From c6f3af75c60cafa964b25464ef70b4a9e8511519 Mon Sep 17 00:00:00 2001 From: abhijeetgangan Date: Sun, 8 Mar 2026 01:56:38 -0800 Subject: [PATCH 2/2] restructure --- mace/calculators/__init__.py | 2 +- mace/calculators/torch_sim.py | 255 ---------------------------------- tests/test_torch_sim.py | 147 -------------------- tests/test_torchsim.py | 168 ++++++++++++++++++++-- 4 files changed, 161 insertions(+), 411 deletions(-) delete mode 100644 mace/calculators/torch_sim.py delete mode 100644 tests/test_torch_sim.py diff --git a/mace/calculators/__init__.py b/mace/calculators/__init__.py index 68f4e113f..3315275ad 100644 --- a/mace/calculators/__init__.py +++ b/mace/calculators/__init__.py @@ -1,7 +1,7 @@ from .foundations_models import mace_anicc, mace_mp, mace_off, mace_omol, mace_polar from .lammps_mace import LAMMPS_MACE from .mace import MACECalculator -from .torch_sim import MaceTorchSimModel +from .mace_torchsim import MaceTorchSimModel __all__ = [ "MACECalculator", diff --git a/mace/calculators/torch_sim.py b/mace/calculators/torch_sim.py deleted file mode 100644 index 8ff0fc657..000000000 --- a/mace/calculators/torch_sim.py +++ /dev/null @@ -1,255 +0,0 @@ -from pathlib import Path -from typing import Any, Callable, Dict, Optional, Union - -import torch - -from mace.tools import atomic_numbers_to_indices, utils - -try: - import torch_sim as ts - from torch_sim.models.interface import ModelInterface - from torch_sim.neighbors import torchsim_nl - - _TORCHSIM_IMPORT_ERROR: Optional[ImportError] = None -except ImportError as exc: - ts = None # type: ignore[assignment] - torchsim_nl = None # type: ignore[assignment] - _TORCHSIM_IMPORT_ERROR = exc - - class ModelInterface(torch.nn.Module): # type: ignore[no-redef] # pylint: disable=abstract-method - """Fallback base class when torch-sim is not installed.""" - - -def to_one_hot( - indices: torch.Tensor, num_classes: int, dtype: torch.dtype -) -> torch.Tensor: - """Generate one-hot vectors from class indices.""" - shape = indices.shape[:-1] + (num_classes,) - out = torch.zeros(shape, device=indices.device, dtype=dtype).view(shape) - out.scatter_(dim=-1, index=indices, value=1) - return out.view(*shape) - - -class MaceTorchSimModel(ModelInterface): - """TorchSim wrapper around a MACE model.""" - - def __init__( - self, - model: Union[str, Path, torch.nn.Module], - device: Optional[torch.device] = None, - dtype: torch.dtype = torch.float64, - neighbor_list_fn: Optional[Callable] = None, - compute_forces: bool = True, - compute_stress: bool = True, - enable_cueq: bool = False, - atomic_numbers: Optional[torch.Tensor] = None, - system_idx: Optional[torch.Tensor] = None, - ) -> None: - if _TORCHSIM_IMPORT_ERROR is not None: - raise ImportError( - "MaceTorchSimModel requires torch-sim-atomistic. " - "Install with `pip install torch-sim-atomistic` " - "or `pip install -e '.[torchsim]'`." - ) from _TORCHSIM_IMPORT_ERROR - - super().__init__() - self._device = device or torch.device( - "cuda" if torch.cuda.is_available() else "cpu" - ) - self._dtype = dtype - self._compute_forces = compute_forces - self._compute_stress = compute_stress - self._memory_scales_with = "n_atoms_x_density" - self.neighbor_list_fn = neighbor_list_fn or torchsim_nl - - if isinstance(model, (str, Path)): - self.model = torch.load(str(model), map_location=self.device) - elif isinstance(model, torch.nn.Module): - self.model = model.to(self.device) - else: - raise TypeError("model must be a path or torch.nn.Module") - - self.model = self.model.eval().to(device=self._device) - if self.dtype is not None: - self.model = self.model.to(dtype=self.dtype) - - if enable_cueq: - try: - from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq - except ImportError as exc: - raise ImportError( - "cuequivariance is not installed so CuEq acceleration cannot be used" - ) from exc - self.model = run_e3nn_to_cueq(self.model, device=self.device.type) - - self.r_max = self.model.r_max - self.z_table = utils.AtomicNumberTable( - [int(z) for z in self.model.atomic_numbers] - ) - self.model.atomic_numbers = ( - self.model.atomic_numbers.detach().clone().to(device=self.device) - ) - - self.atomic_numbers_in_init = atomic_numbers is not None - if atomic_numbers is not None: - if system_idx is None: - system_idx = torch.zeros( - len(atomic_numbers), dtype=torch.long, device=self.device - ) - self.setup_from_system_idx(atomic_numbers, system_idx) - - def setup_from_system_idx( - self, atomic_numbers: torch.Tensor, system_idx: torch.Tensor - ) -> None: - """Prepare cached batch tensors from atom-wise system assignment.""" - if atomic_numbers.shape[0] != system_idx.shape[0]: - raise ValueError("atomic_numbers and system_idx must have same shape[0]") - - self.atomic_numbers = atomic_numbers.to(device=self.device, dtype=torch.long) - self.system_idx = system_idx.to(device=self.device, dtype=torch.long) - - if self.system_idx.numel() == 0: - raise ValueError("at least one atom is required") - - self.n_systems = int(self.system_idx.max().item()) + 1 - self.n_atoms_per_system = [] - ptr = [0] - for idx in range(self.n_systems): - n_atoms = int((self.system_idx == idx).sum().item()) - self.n_atoms_per_system.append(n_atoms) - ptr.append(ptr[-1] + n_atoms) - - self.ptr = torch.tensor(ptr, dtype=torch.long, device=self.device) - self.total_atoms = int(self.atomic_numbers.shape[0]) - - atomic_indices = torch.tensor( - atomic_numbers_to_indices( - self.atomic_numbers.detach().cpu().numpy(), z_table=self.z_table - ), - dtype=torch.long, - device=self.device, - ).unsqueeze(-1) - - self.node_attrs = to_one_hot( - atomic_indices, - num_classes=len(self.z_table), - dtype=self.dtype, - ) - - def forward(self, state: Any) -> Dict[str, torch.Tensor]: - """Compute energies, forces and stresses for one or more systems.""" - if ts is None: - raise RuntimeError( - "torch-sim is required to call MaceTorchSimModel.forward" - ) - - if isinstance(state, ts.SimState): - sim_state = state.clone() - else: - state_dict = dict(state) - if "masses" not in state_dict: - state_dict["masses"] = torch.ones_like(state_dict["positions"]) - sim_state = ts.SimState(**state_dict) - - if sim_state.device != self.device or sim_state.dtype != self.dtype: - sim_state = sim_state.to(self.device, self.dtype) - - state_atomic_numbers = getattr(sim_state, "atomic_numbers", None) - if state_atomic_numbers is None and not self.atomic_numbers_in_init: - raise ValueError( - "atomic_numbers must be provided in the constructor or in forward." - ) - - if state_atomic_numbers is not None and self.atomic_numbers_in_init: - if not torch.equal(state_atomic_numbers, self.atomic_numbers): - raise ValueError( - "atomic_numbers in state do not match constructor values." - ) - - if sim_state.system_idx is None: - if not hasattr(self, "system_idx"): - raise ValueError( - "system_idx must be provided if not set during initialization" - ) - sim_state.system_idx = self.system_idx - - if not self.atomic_numbers_in_init: - cached_atomic_numbers = getattr(self, "atomic_numbers", None) - cached_system_idx = getattr(self, "system_idx", None) - needs_setup = state_atomic_numbers is not None and ( - cached_atomic_numbers is None - or cached_system_idx is None - or not torch.equal(state_atomic_numbers, cached_atomic_numbers) - or not torch.equal(sim_state.system_idx, cached_system_idx) - ) - if needs_setup: - self.setup_from_system_idx(state_atomic_numbers, sim_state.system_idx) - - wrapped_positions = ( - ts.transforms.pbc_wrap_batched( - sim_state.positions, - sim_state.cell, - sim_state.system_idx, - sim_state.pbc, - ) - if sim_state.pbc.any() - else sim_state.positions - ) - - edge_index, mapping_system, unit_shifts = self.neighbor_list_fn( - wrapped_positions, - sim_state.row_vector_cell, - sim_state.pbc, - self.r_max, - sim_state.system_idx, - ) - shifts = ts.transforms.compute_cell_shifts( - sim_state.row_vector_cell, unit_shifts, mapping_system - ) - - data_dict = { - "ptr": self.ptr, - "node_attrs": self.node_attrs, - "batch": sim_state.system_idx, - "pbc": sim_state.pbc, - "cell": sim_state.row_vector_cell, - "positions": wrapped_positions, - "edge_index": edge_index, - "unit_shifts": unit_shifts, - "shifts": shifts, - "total_charge": sim_state.charge, - "total_spin": sim_state.spin, - } - - out = self.model( - data_dict, - compute_force=self.compute_forces, - compute_stress=self.compute_stress, - ) - - n_systems = sim_state.n_systems - results: Dict[str, torch.Tensor] = {} - - energy = out.get("energy") - if energy is None: - results["energy"] = torch.zeros( - n_systems, device=self.device, dtype=self.dtype - ) - else: - results["energy"] = energy.detach() - - if self.compute_forces: - forces = out.get("forces") - if forces is None: - forces = torch.zeros_like(sim_state.positions) - results["forces"] = forces.detach() - - if self.compute_stress: - stress = out.get("stress") - if stress is None: - stress = torch.zeros( - n_systems, 3, 3, device=self.device, dtype=self.dtype - ) - results["stress"] = stress.detach() - - return results diff --git a/tests/test_torch_sim.py b/tests/test_torch_sim.py deleted file mode 100644 index 73be74c4d..000000000 --- a/tests/test_torch_sim.py +++ /dev/null @@ -1,147 +0,0 @@ -import pytest -import torch - -from mace.calculators import mace_mp, mace_off - -try: - import torch_sim as ts - from torch_sim.models.interface import validate_model_outputs - from torch_sim.testing import ( - SIMSTATE_BULK_GENERATORS, - SIMSTATE_MOLECULE_GENERATORS, - assert_model_calculator_consistency, - ) -except (ImportError, ModuleNotFoundError): - pytest.skip("Skipping torch-sim tests due to ImportError", allow_module_level=True) - -from mace.calculators.torch_sim import MaceTorchSimModel - - -DEVICE = torch.device("cpu") -DTYPE = torch.float64 -MACE_MP_MODEL = "small-0b" -MACE_OFF_MODEL = "small" - - -def _to_dtype_name(dtype: torch.dtype) -> str: - if dtype == torch.float32: - return "float32" - if dtype == torch.float64: - return "float64" - raise ValueError(f"Unsupported dtype {dtype}") - - -@pytest.fixture(scope="module") -def raw_mace_mp_model(): - return mace_mp( - model=MACE_MP_MODEL, - device=str(DEVICE), - default_dtype=_to_dtype_name(DTYPE), - return_raw_model=True, - ) - - -@pytest.fixture(scope="module") -def raw_mace_off_model(): - return mace_off( - model=MACE_OFF_MODEL, - device=str(DEVICE), - default_dtype=_to_dtype_name(DTYPE), - return_raw_model=True, - ) - - -@pytest.fixture -def ase_mace_mp_calculator(): - return mace_mp( - model=MACE_MP_MODEL, - device=str(DEVICE), - default_dtype=_to_dtype_name(DTYPE), - dispersion=False, - ) - - -@pytest.fixture -def ase_mace_off_calculator(): - return mace_off( - model=MACE_OFF_MODEL, - device=str(DEVICE), - default_dtype=_to_dtype_name(DTYPE), - ) - - -@pytest.fixture -def ts_mace_mp_model(raw_mace_mp_model): - return MaceTorchSimModel( - model=raw_mace_mp_model, - device=DEVICE, - dtype=DTYPE, - compute_forces=True, - compute_stress=True, - ) - - -@pytest.fixture -def ts_mace_off_model(raw_mace_off_model): - return MaceTorchSimModel( - model=raw_mace_off_model, - device=DEVICE, - dtype=DTYPE, - compute_forces=True, - compute_stress=False, - ) - - -@pytest.mark.parametrize("sim_state_name", ("si_sim_state", "rattled_si_sim_state")) -def test_torch_sim_mace_mp_consistency( - sim_state_name, ts_mace_mp_model, ase_mace_mp_calculator -): - sim_state = SIMSTATE_BULK_GENERATORS[sim_state_name](DEVICE, DTYPE) - assert_model_calculator_consistency( - model=ts_mace_mp_model, - calculator=ase_mace_mp_calculator, - sim_state=sim_state, - ) - - -@pytest.mark.parametrize("sim_state_name", ("benzene_sim_state",)) -def test_torch_sim_mace_off_consistency( - sim_state_name, ts_mace_off_model, ase_mace_off_calculator -): - sim_state = SIMSTATE_MOLECULE_GENERATORS[sim_state_name](DEVICE, DTYPE) - assert_model_calculator_consistency( - model=ts_mace_off_model, - calculator=ase_mace_off_calculator, - sim_state=sim_state, - ) - - -@pytest.mark.parametrize("dtype", (torch.float32, torch.float64)) -def test_torch_sim_mace_dtype_smoke(raw_mace_mp_model, dtype: torch.dtype): - model = MaceTorchSimModel( - model=raw_mace_mp_model, - device=DEVICE, - dtype=dtype, - compute_forces=True, - compute_stress=True, - ) - state = SIMSTATE_BULK_GENERATORS["si_sim_state"](DEVICE, dtype) - output = model(state) - - assert output["energy"].shape == (1,) - assert torch.is_floating_point(output["energy"]) - assert output["forces"].shape == state.positions.shape - assert torch.is_floating_point(output["forces"]) - assert output["stress"].shape == (1, 3, 3) - - -def test_torch_sim_mace_off_output_keys(ts_mace_off_model): - state = SIMSTATE_MOLECULE_GENERATORS["benzene_sim_state"](DEVICE, DTYPE) - output = ts_mace_off_model(state) - assert "energy" in output - assert "forces" in output - assert "stress" not in output - - -def test_torch_sim_mace_validate_outputs(ts_mace_mp_model): - validate_model_outputs(ts_mace_mp_model, DEVICE, DTYPE) diff --git a/tests/test_torchsim.py b/tests/test_torchsim.py index 68ec7b752..9ef0d9548 100644 --- a/tests/test_torchsim.py +++ b/tests/test_torchsim.py @@ -8,14 +8,19 @@ import numpy as np import pytest import torch -from ase import build + +from mace.calculators import mace_mp, mace_off try: import torch_sim as ts - - TORCHSIM_AVAILABLE = True -except ImportError: - TORCHSIM_AVAILABLE = False + from torch_sim.models.interface import validate_model_outputs + from torch_sim.testing import ( + SIMSTATE_BULK_GENERATORS, + SIMSTATE_MOLECULE_GENERATORS, + assert_model_calculator_consistency, + ) +except (ImportError, ModuleNotFoundError): + pytest.skip("Skipping torch-sim tests due to ImportError", allow_module_level=True) try: import cuequivariance as cue # noqa: F401 @@ -24,13 +29,95 @@ except ImportError: CUET_AVAILABLE = False -pytestmark = pytest.mark.skipif( - not TORCHSIM_AVAILABLE, reason="torch-sim not installed" -) +from mace.calculators.mace_torchsim import MaceTorchSimModel pytest_mace_dir = Path(__file__).parent.parent run_train = Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" +DEVICE = torch.device("cpu") +DTYPE = torch.float64 +MACE_MP_MODEL = "small-0b" +MACE_OFF_MODEL = "small" + + +def _to_dtype_name(dtype: torch.dtype) -> str: + if dtype == torch.float32: + return "float32" + if dtype == torch.float64: + return "float64" + raise ValueError(f"Unsupported dtype {dtype}") + + +# --------------------------------------------------------------------------- +# Fixtures for foundation-model tests +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def raw_mace_mp_model(): + return mace_mp( + model=MACE_MP_MODEL, + device=str(DEVICE), + default_dtype=_to_dtype_name(DTYPE), + return_raw_model=True, + ) + + +@pytest.fixture(scope="module") +def raw_mace_off_model(): + return mace_off( + model=MACE_OFF_MODEL, + device=str(DEVICE), + default_dtype=_to_dtype_name(DTYPE), + return_raw_model=True, + ) + + +@pytest.fixture +def ase_mace_mp_calculator(): + return mace_mp( + model=MACE_MP_MODEL, + device=str(DEVICE), + default_dtype=_to_dtype_name(DTYPE), + dispersion=False, + ) + + +@pytest.fixture +def ase_mace_off_calculator(): + return mace_off( + model=MACE_OFF_MODEL, + device=str(DEVICE), + default_dtype=_to_dtype_name(DTYPE), + ) + + +@pytest.fixture +def ts_mace_mp_model(raw_mace_mp_model): + return MaceTorchSimModel( + model=raw_mace_mp_model, + device=DEVICE, + dtype=DTYPE, + compute_forces=True, + compute_stress=True, + ) + + +@pytest.fixture +def ts_mace_off_model(raw_mace_off_model): + return MaceTorchSimModel( + model=raw_mace_off_model, + device=DEVICE, + dtype=DTYPE, + compute_forces=True, + compute_stress=False, + ) + + +# --------------------------------------------------------------------------- +# Fixtures for locally-trained-model tests +# --------------------------------------------------------------------------- + @pytest.fixture(scope="module") def trained_model_path(tmp_path_factory): @@ -123,6 +210,71 @@ def water_atoms(): return atoms +# --------------------------------------------------------------------------- +# Foundation-model tests +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("sim_state_name", ("si_sim_state", "rattled_si_sim_state")) +def test_torch_sim_mace_mp_consistency( + sim_state_name, ts_mace_mp_model, ase_mace_mp_calculator +): + sim_state = SIMSTATE_BULK_GENERATORS[sim_state_name](DEVICE, DTYPE) + assert_model_calculator_consistency( + model=ts_mace_mp_model, + calculator=ase_mace_mp_calculator, + sim_state=sim_state, + ) + + +@pytest.mark.parametrize("sim_state_name", ("benzene_sim_state",)) +def test_torch_sim_mace_off_consistency( + sim_state_name, ts_mace_off_model, ase_mace_off_calculator +): + sim_state = SIMSTATE_MOLECULE_GENERATORS[sim_state_name](DEVICE, DTYPE) + assert_model_calculator_consistency( + model=ts_mace_off_model, + calculator=ase_mace_off_calculator, + sim_state=sim_state, + ) + + +@pytest.mark.parametrize("dtype", (torch.float32, torch.float64)) +def test_torch_sim_mace_dtype_smoke(raw_mace_mp_model, dtype: torch.dtype): + model = MaceTorchSimModel( + model=raw_mace_mp_model, + device=DEVICE, + dtype=dtype, + compute_forces=True, + compute_stress=True, + ) + state = SIMSTATE_BULK_GENERATORS["si_sim_state"](DEVICE, dtype) + output = model(state) + + assert output["energy"].shape == (1,) + assert torch.is_floating_point(output["energy"]) + assert output["forces"].shape == state.positions.shape + assert torch.is_floating_point(output["forces"]) + assert output["stress"].shape == (1, 3, 3) + + +def test_torch_sim_mace_off_output_keys(ts_mace_off_model): + state = SIMSTATE_MOLECULE_GENERATORS["benzene_sim_state"](DEVICE, DTYPE) + output = ts_mace_off_model(state) + assert "energy" in output + assert "forces" in output + assert "stress" not in output + + +def test_torch_sim_mace_validate_outputs(ts_mace_mp_model): + validate_model_outputs(ts_mace_mp_model, DEVICE, DTYPE) + + +# --------------------------------------------------------------------------- +# Locally-trained-model tests +# --------------------------------------------------------------------------- + + def test_torchsim_basic(trained_model_path, water_atoms): from mace.calculators.mace_torchsim import MaceTorchSimModel