diff --git a/metatomic-torch/src/model.cpp b/metatomic-torch/src/model.cpp index 7bc7e347..dac2b65a 100644 --- a/metatomic-torch/src/model.cpp +++ b/metatomic-torch/src/model.cpp @@ -1141,6 +1141,7 @@ static std::map KNOWN_QUANTITIES = { {"A/fs", 1e1}, {"m/s", 1e6}, {"nm/ps", 1e3}, + {"(eV/u)^(1/2)", 101.80506}, }, { // alternative names }}}, diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 84e5d4d9..54f0caa2 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -65,7 +65,7 @@ def _get_charges(atoms: ase.Atoms) -> np.ndarray: "velocities": { "quantity": "velocity", "getter": ase.Atoms.get_velocities, - "unit": "nm/fs", + "unit": "(eV/u)^(1/2)", }, "charges": { "quantity": "charge", @@ -99,6 +99,10 @@ def _get_charges(atoms: ase.Atoms) -> np.ndarray: }, } +IMPLEMENTED_PROPERTIES = [ + "heat_flux", +] + class MetatomicCalculator(ase.calculators.calculator.Calculator): """ @@ -318,7 +322,7 @@ def __init__( # We do our own check to verify if a property is implemented in `calculate()`, # so we pretend to be able to compute all properties ASE knows about. 
from typing import Dict, List, Optional

import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from vesin.metatomic import compute_requested_neighbors

from metatomic.torch import (
    AtomisticModel,
    ModelEvaluationOptions,
    ModelOutput,
    System,
)


def wrap_positions(positions: torch.Tensor, cell: torch.Tensor) -> torch.Tensor:
    """
    Wrap positions into the periodic cell.

    :param positions: ``(n_atoms, 3)`` Cartesian positions
    :param cell: ``(3, 3)`` cell matrix with the cell vectors as rows
    :return: positions mapped back into the cell (fractional coordinates in [0, 1))
    """
    fractional_positions = torch.einsum("iv,vk->ik", positions, cell.inverse())
    fractional_positions -= torch.floor(fractional_positions)
    return torch.einsum("iv,vk->ik", fractional_positions, cell)


def check_collisions(
    cell: torch.Tensor, positions: torch.Tensor, cutoff: float, skin: float
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Detect atoms that lie within a cutoff distance from the periodic cell boundaries,
    i.e. have interactions with atoms at the opposite end of the cell.

    :param cell: ``(3, 3)`` cell matrix with the cell vectors as rows
    :param positions: ``(n_atoms, 3)`` positions, already wrapped into the cell
    :param cutoff: interaction cutoff distance
    :param skin: extra safety margin added to the cutoff
    :raises ValueError: if any perpendicular cell height is smaller than
        ``cutoff + skin``, in which case single replicas are not enough
    :return: a tuple of (``(n_atoms, 6)`` boolean collision flags ordered as
        ``(x_lo, x_hi, y_lo, y_hi, z_lo, z_hi)``, ``(n_atoms, 3)`` positions
        projected onto the cell face normals)
    """
    inv_cell = cell.inverse()
    recip = inv_cell.T
    norms = torch.linalg.norm(recip, dim=1)
    # perpendicular distance between each pair of opposite cell faces
    heights = 1.0 / norms
    if heights.min() < (cutoff + skin):
        raise ValueError(
            "Cell is too small compared to (cutoff + skin) = "
            + str(cutoff + skin)
            + ". Ensure that all perpendicular cell heights are at least this "
            "length. Currently, the minimum perpendicular cell height is "
            + str(heights.min())
            + "."
        )

    cutoff += skin
    normals = recip / norms[:, None]
    # signed distance of every atom from the three "low" cell faces
    norm_coords = torch.einsum("iv,kv->ik", positions, normals)
    collisions = torch.hstack(
        [norm_coords <= cutoff, norm_coords >= heights - cutoff],
    ).to(device=positions.device)

    return (
        # reorder to (x_lo, x_hi, y_lo, y_hi, z_lo, z_hi)
        collisions[:, [0, 3, 1, 4, 2, 5]],
        norm_coords,
    )


def collisions_to_replicas(collisions: torch.Tensor) -> torch.Tensor:
    """
    Convert boundary-collision flags into a boolean mask over all periodic image
    displacements in {0, +1, -1}^3. e.g. for an atom colliding with the x_lo and y_hi
    boundaries, we need the replicas at (1, 0, 0), (0, -1, 0), (1, -1, 0) image cells.

    :param collisions: ``(n_atoms, 6)`` flags for collisions with
        ``(x_lo, x_hi, y_lo, y_hi, z_lo, z_hi)``
    :return: ``(n_atoms, 3, 3, 3)`` boolean mask over image displacements, where
        along each axis (order x, y, z):
        index 0 means no replica needed,
        index 1 means a +1 replica (atom near the low boundary, image placed just
        outside the high boundary),
        index 2 means a -1 replica (atom near the high boundary, image placed just
        outside the low boundary).
    """
    origin = torch.full(
        (len(collisions),), True, dtype=torch.bool, device=collisions.device
    )
    axs = torch.vstack([origin, collisions[:, 0], collisions[:, 1]])
    ays = torch.vstack([origin, collisions[:, 2], collisions[:, 3]])
    azs = torch.vstack([origin, collisions[:, 4], collisions[:, 5]])
    # leverage broadcasting: (3, 1, 1, N) & (1, 3, 1, N) & (1, 1, 3, N)
    outs = axs[:, None, None] & ays[None, :, None] & azs[None, None, :]
    outs = torch.movedim(outs, -1, 0)
    outs[:, 0, 0, 0] = False  # not close to any boundary -> no replica needed
    return outs.to(device=collisions.device)


def generate_replica_atoms(
    types: torch.Tensor,
    positions: torch.Tensor,
    cell: torch.Tensor,
    replicas: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    For atoms near the low boundary (x_lo/y_lo/z_lo), generate their images shifted
    by +1 cell vector (i.e., placed just outside the high boundary).
    For atoms near the high boundary (x_hi/y_hi/z_hi), generate images shifted by -1
    cell vector.

    :param replicas: ``(n_atoms, 3, 3, 3)`` mask from :py:func:`collisions_to_replicas`
    :return: a tuple (index of the source atom for each replica, replica atomic
        types, replica positions)
    """
    replicas = torch.argwhere(replicas)
    replica_idx = replicas[:, 0]
    # map mask indices (0, 1, 2) to image displacements (0, +1, -1)
    replica_offsets = torch.tensor(
        [0, 1, -1], device=positions.device, dtype=positions.dtype
    )[replicas[:, 1:]]
    replica_positions = positions[replica_idx]
    replica_positions += torch.einsum("iA,Aa->ia", replica_offsets, cell)

    return replica_idx, types[replica_idx], replica_positions


def unfold_system(metatomic_system: System, cutoff: float, skin: float = 0.5) -> System:
    """
    Unfold a periodic system by generating replica atoms for those near the cell
    boundaries within the specified cutoff distance.
    The unfolded system has no periodic boundary conditions.

    The per-atom ``"masses"`` and ``"velocities"`` data of the input system are
    copied over to the replica atoms.
    """
    wrapped_positions = wrap_positions(
        metatomic_system.positions, metatomic_system.cell
    )
    collisions, _ = check_collisions(
        metatomic_system.cell, wrapped_positions, cutoff, skin
    )
    replicas = collisions_to_replicas(collisions)
    replica_idx, replica_types, replica_positions = generate_replica_atoms(
        metatomic_system.types, wrapped_positions, metatomic_system.cell, replicas
    )

    device = metatomic_system.device
    unfolded_types = torch.cat([metatomic_system.types, replica_types])
    unfolded_positions = torch.cat([wrapped_positions, replica_positions])
    # map every unfolded atom (original first, then replicas) to its source atom
    unfolded_idx = torch.cat(
        [
            torch.arange(len(metatomic_system.types), device=device),
            replica_idx,
        ]
    )
    unfolded_n_atoms = len(unfolded_types)

    masses_block = metatomic_system.get_data("masses").block()
    velocities_block = metatomic_system.get_data("velocities").block()
    samples = Labels(
        ["atoms"],
        torch.arange(unfolded_n_atoms, device=device).reshape(-1, 1),
    )
    unfolded_masses_block = TensorBlock(
        values=masses_block.values[unfolded_idx],
        samples=samples,
        components=masses_block.components,
        properties=masses_block.properties,
    )
    unfolded_velocities_block = TensorBlock(
        values=velocities_block.values[unfolded_idx],
        samples=samples,
        components=velocities_block.components,
        properties=velocities_block.properties,
    )

    unfolded_system = System(
        types=unfolded_types,
        positions=unfolded_positions,
        # the unfolded system is treated as an isolated (non-periodic) cluster
        cell=torch.zeros((3, 3), dtype=unfolded_positions.dtype, device=device),
        pbc=torch.tensor([False, False, False], device=device),
    )
    keys = Labels("_", torch.tensor([[0]], device=device))
    unfolded_system.add_data("masses", TensorMap(keys, [unfolded_masses_block]))
    unfolded_system.add_data("velocities", TensorMap(keys, [unfolded_velocities_block]))
    return unfolded_system.to(metatomic_system.dtype, metatomic_system.device)


class HeatFluxWrapper(torch.nn.Module):
    """
    A wrapper around an AtomisticModel that computes the heat flux of a system using
    the unfolded system approach. The heat flux is computed using the atomic energies
    (eV), positions (Å), masses (u), velocities (Å/fs), and the energy gradients.

    The unfolded system is generated by creating replica atoms for those near the
    cell boundaries within the interaction range of the wrapped model. The wrapper
    adds the heat flux to the model's outputs under the key "extra::heat_flux".

    For more details on the heat flux calculation, see `Langer, M. F., et al., Heat
    flux for semilocal machine-learning potentials. (2023). Physical Review B, 108,
    L100302.`
    """

    def __init__(self, model: AtomisticModel, skin: float = 0.5):
        """
        :param model: the :py:class:`AtomisticModel` to wrap, which should be able to
            compute atomic energies and their gradients with respect to positions
        :param skin: the skin parameter for unfolding the system. The wrapper will
            generate replica atoms for those within (interaction_range + skin)
            distance from the cell boundaries. A skin results in more replica atoms
            and thus higher computational cost, but ensures that the heat flux is
            computed correctly.
        :raises ValueError: if the wrapped model's energy output is not in eV
        """
        super().__init__()

        self._model = model
        self.skin = skin
        self._interaction_range = model.capabilities().interaction_range

        self._requested_inputs = {
            "masses": ModelOutput(quantity="mass", unit="u", per_atom=True),
            "velocities": ModelOutput(quantity="velocity", unit="A/fs", per_atom=True),
        }

        hf_output = ModelOutput(
            quantity="heat_flux",
            unit="",
            explicit_gradients=[],
            per_atom=False,
        )
        outputs = self._model.capabilities().outputs.copy()
        outputs["extra::heat_flux"] = hf_output
        # also advertise the new output on the wrapped model's capabilities, so
        # callers inspecting the model see "extra::heat_flux"
        self._model.capabilities().outputs["extra::heat_flux"] = hf_output
        if outputs["energy"].unit != "eV":
            raise ValueError(
                "HeatFluxWrapper can only be used with energy outputs in eV"
            )
        energies_output = ModelOutput(
            quantity="energy", unit=outputs["energy"].unit, per_atom=True
        )
        self._unfolded_run_options = ModelEvaluationOptions(
            length_unit=self._model.capabilities().length_unit,
            outputs={"energy": energies_output},
            selected_atoms=None,
        )

    def requested_inputs(self) -> Dict[str, ModelOutput]:
        """Extra per-atom data ("masses" and "velocities") this wrapper needs."""
        return self._requested_inputs

    def barycenter_and_atomic_energies(self, system: System, n_atoms: int):
        """
        Evaluate the wrapped model on ``system`` and compute the energy barycenter
        ``sum_i E_i r_i`` over the first ``n_atoms`` atoms (the non-replica atoms),
        together with all atomic energies and the total energy of the first
        ``n_atoms`` atoms.

        Positions are detached for the barycenter so that only the energies carry
        gradients with respect to the (unfolded) positions.
        """
        atomic_e = self._model([system], self._unfolded_run_options, False)["energy"][
            0
        ].values.flatten()
        total_e = atomic_e[:n_atoms].sum()
        r_aux = system.positions.detach()
        barycenter = torch.einsum("i,ik->k", atomic_e[:n_atoms], r_aux[:n_atoms])

        return barycenter, atomic_e, total_e

    def calc_unfolded_heat_flux(self, system: System) -> torch.Tensor:
        """
        Compute the heat flux of ``system`` using the unfolded-system formulation:
        a potential term built from the gradients of the energy barycenter and of
        the total energy, plus a convective term.
        """
        n_atoms = len(system.positions)
        unfolded_system = unfold_system(system, self._interaction_range, self.skin).to(
            "cpu"
        )
        # neighbor lists are computed on CPU, then the system is moved back
        compute_requested_neighbors(
            unfolded_system, self._unfolded_run_options.length_unit, model=self._model
        )
        unfolded_system = unfolded_system.to(system.device)
        velocities: torch.Tensor = (
            unfolded_system.get_data("velocities").block().values.reshape(-1, 3)
        )
        masses: torch.Tensor = (
            unfolded_system.get_data("masses").block().values.reshape(-1)
        )
        barycenter, atomic_e, total_e = self.barycenter_and_atomic_energies(
            unfolded_system, n_atoms
        )

        # term1_k = sum_j (d/dr_j sum_i E_i r_ik) . v_j
        term1 = torch.zeros(
            (3), device=system.positions.device, dtype=system.positions.dtype
        )
        for i in range(3):
            grad_i = torch.autograd.grad(
                [barycenter[i]],
                [unfolded_system.positions],
                retain_graph=True,
                create_graph=False,
            )[0]
            grad_i = torch.jit._unwrap_optional(grad_i)
            term1[i] = (grad_i * velocities).sum()

        # term2_k = sum_j r_jk (dE/dr_j . v_j)
        go = torch.jit.annotate(
            Optional[List[Optional[torch.Tensor]]], [torch.ones_like(total_e)]
        )
        grads = torch.autograd.grad(
            [total_e],
            [unfolded_system.positions],
            grad_outputs=go,
        )[0]
        grads = torch.jit._unwrap_optional(grads)
        term2 = (
            unfolded_system.positions * (grads * velocities).sum(dim=1, keepdim=True)
        ).sum(dim=0)

        hf_pot = term1 - term2

        # convective term (E_i + 1/2 m_i v_i^2) v_i; only the kinetic part needs
        # unit conversion, the atomic energies are already in eV
        hf_conv = (
            (
                atomic_e[:n_atoms]
                + 0.5
                * masses[:n_atoms]
                * torch.linalg.norm(velocities[:n_atoms], dim=1) ** 2
                * 103.6427  # u*A^2/fs^2 to eV
            )[:, None]
            * velocities[:n_atoms]
        ).sum(dim=0)

        return hf_pot + hf_conv

    def forward(
        self,
        systems: List[System],
        outputs: Dict[str, ModelOutput],
        selected_atoms: Optional[Labels],
    ) -> Dict[str, TensorMap]:
        """
        Evaluate the wrapped model; if "extra::heat_flux" is requested, also
        compute one heat flux vector per system and add it to the results.
        """
        run_options = ModelEvaluationOptions(
            length_unit=self._model.capabilities().length_unit,
            outputs=outputs,
            selected_atoms=None,
        )
        results = self._model(systems, run_options, False)

        if "extra::heat_flux" not in outputs:
            return results

        device = systems[0].device
        heat_fluxes: List[torch.Tensor] = []
        for system in systems:
            # gradients with respect to positions are required for the heat flux
            system.positions.requires_grad_(True)
            heat_fluxes.append(self.calc_unfolded_heat_flux(system))

        samples = Labels(
            ["system"], torch.arange(len(systems), device=device).reshape(-1, 1)
        )

        hf_block = TensorBlock(
            values=torch.vstack(heat_fluxes).reshape(-1, 3, 1).to(device=device),
            samples=samples,
            components=[Labels(["xyz"], torch.arange(3, device=device).reshape(-1, 1))],
            properties=Labels(["heat_flux"], torch.tensor([[0]], device=device)),
        )
        results["extra::heat_flux"] = TensorMap(
            Labels("_", torch.tensor([[0]], device=device)), [hf_block]
        )
        return results


class HardyHeatFluxWrapper(torch.nn.Module):
    """
    A wrapper around an AtomisticModel that computes the heat flux of a system as a
    Hardy-type sum, directly from the per-atom energy gradients and interatomic
    displacement vectors, without unfolding the system. The heat flux is computed
    using the atomic energies (eV), positions (Å), masses (u), velocities (Å/fs),
    and the energy gradients.

    The wrapper adds the heat flux to the model's outputs under the key
    "extra::heat_flux".
    """

    def __init__(self, model: AtomisticModel, skin: float = 0.5):
        """
        :param model: the :py:class:`AtomisticModel` to wrap, which should be able to
            compute atomic energies and their gradients with respect to positions
        :param skin: kept for API symmetry with :py:class:`HeatFluxWrapper`; the
            Hardy calculation itself does not unfold the system and does not use it
        :raises ValueError: if the wrapped model's energy output is not in eV
        """
        super().__init__()

        self._model = model
        self.skin = skin
        self._interaction_range = model.capabilities().interaction_range

        self._requested_inputs = {
            "masses": ModelOutput(quantity="mass", unit="u", per_atom=True),
            "velocities": ModelOutput(quantity="velocity", unit="A/fs", per_atom=True),
        }

        hf_output = ModelOutput(
            quantity="heat_flux",
            unit="",
            explicit_gradients=[],
            per_atom=False,
        )
        outputs = self._model.capabilities().outputs.copy()
        outputs["extra::heat_flux"] = hf_output
        # also advertise the new output on the wrapped model's capabilities
        self._model.capabilities().outputs["extra::heat_flux"] = hf_output
        if outputs["energy"].unit != "eV":
            raise ValueError(
                "HardyHeatFluxWrapper can only be used with energy outputs in eV"
            )
        energies_output = ModelOutput(
            quantity="energy", unit=outputs["energy"].unit, per_atom=True
        )
        self._unfolded_run_options = ModelEvaluationOptions(
            length_unit=self._model.capabilities().length_unit,
            outputs={"energy": energies_output},
            selected_atoms=None,
        )

    def requested_inputs(self) -> Dict[str, ModelOutput]:
        """Extra per-atom data ("masses" and "velocities") this wrapper needs."""
        return self._requested_inputs

    def calc_hardy_heat_flux(self, system: System) -> torch.Tensor:
        """
        Compute the heat flux of ``system`` as a Hardy-type sum over per-atom
        energy gradients, plus a convective term.

        NOTE(review): pair displacements are passed through ``wrap_positions``,
        which maps them to fractional coordinates in [0, 1)^3 rather than to the
        minimum image convention ([-0.5, 0.5)) — confirm this is the intended
        convention before relying on these values.
        """
        n_atoms = len(system.positions)
        velocities: torch.Tensor = (
            system.get_data("velocities").block().values.reshape(-1, 3)
        )
        masses: torch.Tensor = system.get_data("masses").block().values.reshape(-1)
        atomic_e = self._model([system], self._unfolded_run_options, False)["energy"][
            0
        ].values.flatten()

        hf_pot = torch.zeros(
            3, dtype=system.positions.dtype, device=system.positions.device
        )
        for i, energy in enumerate(atomic_e):
            grad = torch.autograd.grad(
                [energy],
                [system.positions],
                # only the last backward pass may free the graph
                retain_graph=i != len(atomic_e) - 1,
            )[0]
            grad = torch.jit._unwrap_optional(grad)
            hf_pot += (
                wrap_positions(system.positions[i] - system.positions, system.cell)
                * (grad * velocities).sum(dim=1, keepdim=True)
            ).sum(dim=0)

        # convective term (E_i + 1/2 m_i v_i^2) v_i; only the kinetic part needs
        # unit conversion, the atomic energies are already in eV
        hf_conv = (
            (
                atomic_e[:n_atoms]
                + 0.5
                * masses[:n_atoms]
                * torch.linalg.norm(velocities[:n_atoms], dim=1) ** 2
                * 103.6427  # u*A^2/fs^2 to eV
            )[:, None]
            * velocities[:n_atoms]
        ).sum(dim=0)

        return hf_pot + hf_conv

    def forward(
        self,
        systems: List[System],
        outputs: Dict[str, ModelOutput],
        selected_atoms: Optional[Labels],
    ) -> Dict[str, TensorMap]:
        """
        Evaluate the wrapped model; if "extra::heat_flux" is requested, also
        compute one Hardy heat flux vector per system and add it to the results.
        """
        run_options = ModelEvaluationOptions(
            length_unit=self._model.capabilities().length_unit,
            outputs=outputs,
            selected_atoms=None,
        )
        results = self._model(systems, run_options, False)

        if "extra::heat_flux" not in outputs:
            return results

        device = systems[0].device
        heat_fluxes: List[torch.Tensor] = []
        for system in systems:
            # gradients with respect to positions are required for the heat flux
            system.positions.requires_grad_(True)
            heat_fluxes.append(self.calc_hardy_heat_flux(system))

        samples = Labels(
            ["system"], torch.arange(len(systems), device=device).reshape(-1, 1)
        )

        hf_block = TensorBlock(
            values=torch.vstack(heat_fluxes).reshape(-1, 3, 1).to(device=device),
            samples=samples,
            components=[Labels(["xyz"], torch.arange(3, device=device).reshape(-1, 1))],
            properties=Labels(["heat_flux"], torch.tensor([[0]], device=device)),
        )
        results["extra::heat_flux"] = TensorMap(
            Labels("_", torch.tensor([[0]], device=device)), [hf_block]
        )
        return results
import metatomic_lj_test
import numpy as np
import pytest
import torch
from ase import Atoms
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
from metatensor.torch import Labels, TensorBlock, TensorMap

from metatomic.torch import (
    AtomisticModel,
    ModelCapabilities,
    ModelMetadata,
    ModelOutput,
    System,
)
from metatomic.torch.ase_calculator import MetatomicCalculator
from metatomic.torch.heat_flux import (
    HardyHeatFluxWrapper,
    HeatFluxWrapper,
    check_collisions,
    collisions_to_replicas,
    generate_replica_atoms,
    unfold_system,
    wrap_positions,
)


@pytest.fixture
def model():
    """A Lennard-Jones model for Ar, with energies in eV and lengths in Å."""
    return metatomic_lj_test.lennard_jones_model(
        atomic_type=18,
        cutoff=7.0,
        sigma=3.405,
        epsilon=0.01032,
        length_unit="Angstrom",
        energy_unit="eV",
        with_extension=False,
    )


@pytest.fixture
def atoms():
    """A small periodic Ar supercell with reproducible thermal velocities."""
    cell = np.array([[6.0, 0.0, 0.0], [0.0, 6.0, 0.0], [0.0, 0.0, 6.0]])
    positions = np.array([[3.0, 3.0, 3.0]])
    atoms = Atoms("Ar", scaled_positions=positions, cell=cell, pbc=True).repeat(
        (2, 2, 2)
    )
    MaxwellBoltzmannDistribution(
        atoms, temperature_K=300, rng=np.random.default_rng(42)
    )
    return atoms


def _make_scalar_tensormap(values: torch.Tensor, property_name: str) -> TensorMap:
    """Build a single-block TensorMap holding one scalar property per atom."""
    block = TensorBlock(
        values=values,
        samples=Labels(
            ["atoms"],
            torch.arange(values.shape[0], device=values.device).reshape(-1, 1),
        ),
        components=[],
        properties=Labels([property_name], torch.tensor([[0]], device=values.device)),
    )
    return TensorMap(Labels("_", torch.tensor([[0]], device=values.device)), [block])


def _make_velocity_tensormap(values: torch.Tensor) -> TensorMap:
    """Build a single-block TensorMap holding a per-atom xyz velocity vector."""
    block = TensorBlock(
        values=values,
        samples=Labels(
            ["atoms"],
            torch.arange(values.shape[0], device=values.device).reshape(-1, 1),
        ),
        components=[
            Labels(
                ["xyz"],
                torch.arange(3, device=values.device).reshape(-1, 1),
            )
        ],
        properties=Labels(["velocity"], torch.tensor([[0]], device=values.device)),
    )
    return TensorMap(Labels("_", torch.tensor([[0]], device=values.device)), [block])


def _make_system_with_data(positions: torch.Tensor, cell: torch.Tensor) -> System:
    """Build a periodic System with unit masses and zero velocities attached."""
    types = torch.tensor([1] * len(positions), dtype=torch.int32)
    system = System(
        types=types,
        positions=positions,
        cell=cell,
        pbc=torch.tensor([True, True, True]),
    )
    masses = torch.ones((len(positions), 1), dtype=positions.dtype)
    velocities = torch.zeros((len(positions), 3, 1), dtype=positions.dtype)
    system.add_data("masses", _make_scalar_tensormap(masses, "mass"))
    system.add_data("velocities", _make_velocity_tensormap(velocities))
    return system


def test_wrap_positions_cubic_matches_expected():
    cell = torch.eye(3) * 2.0
    positions = torch.tensor([[-0.1, 0.0, 0.0], [2.1, 1.0, -0.5]])
    wrapped = wrap_positions(positions, cell)
    expected = torch.tensor([[1.9, 0.0, 0.0], [0.1, 1.0, 1.5]])
    assert torch.allclose(wrapped, expected)


def test_check_collisions_cubic_axis_order():
    cell = torch.eye(3) * 2.0
    positions = torch.tensor([[0.1, 1.0, 1.9]])
    collisions, norm_coords = check_collisions(cell, positions, cutoff=0.2, skin=0.0)
    assert torch.allclose(norm_coords, positions)
    assert collisions.shape == (1, 6)
    assert collisions[0].tolist() == [True, False, False, False, False, True]


def test_generate_replica_atoms_cubic_offsets():
    types = torch.tensor([1])
    positions = torch.tensor([[0.1, 1.0, 1.0]])
    cell = torch.eye(3) * 2.0
    collisions = torch.tensor([[True, False, False, False, False, False]])
    replicas = collisions_to_replicas(collisions)
    replica_idx, replica_types, replica_positions = generate_replica_atoms(
        types, positions, cell, replicas
    )
    assert replica_idx.tolist() == [0]
    assert replica_types.tolist() == [1]
    assert torch.allclose(
        replica_positions, positions + torch.tensor([[2.0, 0.0, 0.0]])
    )


def test_wrap_positions_triclinic_fractional_bounds_and_shift():
    cell = torch.tensor(
        [
            [2.0, 0.3, 0.2],
            [0.1, 1.7, 0.4],
            [0.2, 0.5, 1.9],
        ]
    )
    positions = torch.tensor(
        [
            [-0.1, 0.0, 0.0],
            [2.1, 1.6, -0.5],
            [4.2, -0.2, 6.1],
        ]
    )
    inv_cell = cell.inverse()
    wrapped = wrap_positions(positions, cell)
    fractional_before = torch.einsum("iv,vk->ik", positions, inv_cell)
    fractional_after = torch.einsum("iv,vk->ik", wrapped, inv_cell)

    assert torch.all(fractional_after >= 0)
    assert torch.all(fractional_after < 1)

    # wrapping must amount to an integer shift in fractional coordinates
    delta_frac = fractional_after - fractional_before
    rounded = torch.round(delta_frac)
    assert torch.allclose(delta_frac, rounded, atol=1e-6, rtol=0)
    assert torch.allclose(rounded, -torch.floor(fractional_before), atol=1e-6, rtol=0)


def test_check_collisions_triclinic_targets():
    cell = torch.tensor(
        [
            [2.0, 0.3, 0.2],
            [0.1, 1.7, 0.4],
            [0.2, 0.5, 1.9],
        ]
    )
    cutoff = 0.2
    inv_cell = cell.inverse()
    recip = inv_cell.T
    norms = torch.linalg.norm(recip, dim=1)
    heights = 1.0 / norms
    norm_vectors = recip / norms[:, None]

    # choose target face-normal coordinates, then back out Cartesian positions
    target = torch.stack(
        [
            torch.tensor([0.05, 0.6, 0.6]),
            torch.tensor([heights[0] - 0.05, 0.05, heights[2] - 0.1]),
            torch.tensor([0.3, heights[1] - 0.05, 0.1]),
        ]
    )
    positions = target @ torch.inverse(norm_vectors).T

    collisions, norm_coords = check_collisions(cell, positions, cutoff=cutoff, skin=0.0)
    assert torch.allclose(norm_coords, target, atol=1e-6, rtol=0)

    expected_low = target <= cutoff
    expected_high = target >= heights - cutoff
    expected = torch.hstack([expected_low, expected_high])
    expected = expected[:, [0, 3, 1, 4, 2, 5]]

    assert torch.equal(collisions, expected)


def test_check_collisions_raises_on_small_cell():
    cell = torch.eye(3) * 1.0
    positions = torch.zeros((1, 3))
    with pytest.raises(ValueError, match="Cell is too small"):
        check_collisions(cell, positions, cutoff=0.9, skin=0.2)


def test_collisions_to_replicas_combines_displacements():
    collisions = torch.tensor([[True, False, False, True, False, False]])
    replicas = collisions_to_replicas(collisions)
    assert replicas.shape == (1, 3, 3, 3)
    assert not replicas[0, 0, 0, 0].item()

    nonzero = torch.nonzero(replicas, as_tuple=False)
    expected = {
        (0, 1, 0, 0),
        (0, 0, 2, 0),
        (0, 1, 2, 0),
    }
    assert {tuple(row.tolist()) for row in nonzero} == expected


def test_generate_replica_atoms_triclinic_offsets():
    cell = torch.tensor(
        [
            [2.0, 0.3, 0.2],
            [0.1, 1.7, 0.4],
            [0.2, 0.5, 1.9],
        ]
    )
    types = torch.tensor([1])
    positions = torch.tensor([[0.2, 0.4, 0.6]])
    collisions = torch.tensor([[True, False, True, False, True, False]])
    replicas = collisions_to_replicas(collisions)
    replica_idx, replica_types, replica_positions = generate_replica_atoms(
        types, positions, cell, replicas
    )

    assert replica_idx.tolist() == [0, 0, 0, 0, 0, 0, 0]
    assert replica_types.tolist() == [1, 1, 1, 1, 1, 1, 1]

    # every combination of the three +1 displacements must be present
    expected_offsets = [
        cell[0],
        cell[1],
        cell[2],
        cell[0] + cell[1],
        cell[0] + cell[2],
        cell[1] + cell[2],
        cell[0] + cell[1] + cell[2],
    ]
    expected_positions = [positions[0] + offset for offset in expected_offsets]

    for expected in expected_positions:
        assert any(
            torch.allclose(expected, actual, atol=1e-6, rtol=0)
            for actual in replica_positions
        )


def test_unfold_system_adds_replica_and_data():
    cell = torch.eye(3) * 2.0
    positions = torch.tensor([[0.1, 1.0, 1.0]])
    system = _make_system_with_data(positions, cell)
    unfolded = unfold_system(system, cutoff=0.1)

    assert len(unfolded.positions) == 2
    assert torch.all(unfolded.pbc == torch.tensor([False, False, False]))
    assert torch.allclose(unfolded.cell, torch.zeros_like(unfolded.cell))

    masses = unfolded.get_data("masses").block().values
    velocities = unfolded.get_data("velocities").block().values
    assert masses.shape[0] == 2
    assert velocities.shape[0] == 2

    assert torch.allclose(unfolded.positions[0], positions[0])
    assert torch.allclose(
        unfolded.positions[1], positions[0] + torch.tensor([2.0, 0.0, 0.0])
    )


def test_heat_flux_wrapper_requested_inputs():
    class DummyCapabilities:
        def __init__(self):
            self.outputs = {"energy": ModelOutput(quantity="energy", unit="eV")}
            self.length_unit = "A"
            self.interaction_range = 1.0

    class DummyModel:
        def __init__(self):
            self._capabilities = DummyCapabilities()

        def capabilities(self):
            return self._capabilities

        def __call__(self, systems, options, check_consistency):
            results = {}
            if "energy" in options.outputs:
                values = torch.zeros(
                    (len(systems), 1), dtype=systems[0].positions.dtype
                )
                block = TensorBlock(
                    values=values,
                    samples=Labels(
                        ["system"],
                        torch.arange(len(systems), device=values.device).reshape(-1, 1),
                    ),
                    components=[],
                    properties=Labels(
                        ["energy"], torch.tensor([[0]], device=values.device)
                    ),
                )
                results["energy"] = TensorMap(
                    Labels("_", torch.tensor([[0]], device=values.device)), [block]
                )
            return results

    wrapper = HeatFluxWrapper(DummyModel())
    requested = wrapper.requested_inputs()
    assert set(requested.keys()) == {"masses", "velocities"}


def test_unfolded_energy_order_used_for_barycenter():
    class DummyCapabilities:
        def __init__(self):
            self.outputs = {"energy": ModelOutput(quantity="energy", unit="eV")}
            self.length_unit = "A"
            self.interaction_range = 1.0

    class DummyModel:
        def __init__(self):
            self._capabilities = DummyCapabilities()

        def capabilities(self):
            return self._capabilities

        def __call__(self, systems, options, check_consistency):
            # per-atom energies equal to the atom index, so ordering is visible
            system = systems[0]
            n_atoms = len(system.positions)
            values = torch.arange(
                n_atoms, dtype=system.positions.dtype, device=system.positions.device
            ).reshape(-1, 1)
            block = TensorBlock(
                values=values,
                samples=Labels(
                    ["atoms"],
                    torch.arange(n_atoms, device=values.device).reshape(-1, 1),
                ),
                components=[],
                properties=Labels(
                    ["energy"], torch.tensor([[0]], device=values.device)
                ),
            )
            return {
                "energy": TensorMap(
                    Labels("_", torch.tensor([[0]], device=values.device)), [block]
                )
            }

    cell = torch.eye(3) * 10.0
    positions = torch.tensor(
        [
            [0.05, 5.0, 5.0],  # near x_lo -> one replica
            [9.95, 5.5, 5.0],  # near x_hi -> one replica
            [0.05, 6.0, 5.5],  # near x_lo -> one replica
        ]
    )
    system = _make_system_with_data(positions, cell)
    unfolded = unfold_system(system, cutoff=0.1, skin=0.0)
    n_atoms = len(system.positions)
    assert len(unfolded.positions) == n_atoms * 2

    wrapper = HeatFluxWrapper(DummyModel())
    barycenter, atomic_e, total_e = wrapper.barycenter_and_atomic_energies(
        unfolded, n_atoms
    )

    expected_atomic_e = torch.arange(
        len(unfolded.positions),
        dtype=unfolded.positions.dtype,
        device=unfolded.positions.device,
    )
    expected_total_e = expected_atomic_e[:n_atoms].sum()
    expected_barycenter = torch.einsum(
        "i,ik->k", expected_atomic_e[:n_atoms], unfolded.positions[:n_atoms]
    )

    assert torch.allclose(atomic_e, expected_atomic_e)
    assert torch.allclose(total_e, expected_total_e)
    assert torch.allclose(barycenter, expected_barycenter)


def test_heat_flux_wrapper_forward_adds_output(monkeypatch):
    class DummyCapabilities:
        def __init__(self):
            self.outputs = {"energy": ModelOutput(quantity="energy", unit="eV")}
            self.length_unit = "A"
            self.interaction_range = 1.0

    class DummyModel:
        def __init__(self):
            self._capabilities = DummyCapabilities()

        def capabilities(self):
            return self._capabilities

        def __call__(self, systems, options, check_consistency):
            values = torch.zeros((len(systems), 1), dtype=systems[0].positions.dtype)
            block = TensorBlock(
                values=values,
                samples=Labels(
                    ["system"],
                    torch.arange(len(systems), device=values.device).reshape(-1, 1),
                ),
                components=[],
                properties=Labels(
                    ["energy"], torch.tensor([[0]], device=values.device)
                ),
            )
            return {
                "energy": TensorMap(
                    Labels("_", torch.tensor([[0]], device=values.device)), [block]
                )
            }

    def _fake_hf(self, system):
        return torch.tensor(
            [1.0, 2.0, 3.0], device=system.device, dtype=system.positions.dtype
        )

    wrapper = HeatFluxWrapper(DummyModel())
    monkeypatch.setattr(HeatFluxWrapper, "calc_unfolded_heat_flux", _fake_hf)

    cell = torch.eye(3)
    systems = [
        System(
            types=torch.tensor([1], dtype=torch.int32),
            positions=torch.zeros((1, 3)),
            cell=cell,
            pbc=torch.tensor([True, True, True]),
        ),
        System(
            types=torch.tensor([1], dtype=torch.int32),
            positions=torch.ones((1, 3)),
            cell=cell,
            pbc=torch.tensor([True, True, True]),
        ),
    ]

    outputs = {
        "energy": ModelOutput(quantity="energy", unit="eV"),
        "extra::heat_flux": ModelOutput(quantity="heat_flux", unit=""),
    }
    results = wrapper.forward(systems, outputs, None)
    assert "extra::heat_flux" in results
    hf_block = results["extra::heat_flux"].block()
    assert hf_block.values.shape == (2, 3, 1)
    assert torch.allclose(hf_block.values[:, :, 0], torch.tensor([[1.0, 2.0, 3.0]] * 2))


@pytest.mark.parametrize(
    "heat_flux,expected",
    [
        # (HardyHeatFluxWrapper, [[4.0898e-05], [-3.1652e-04], [-2.1660e-04]]),
        (HeatFluxWrapper, [[8.1053e-05], [-1.2710e-05], [-2.8778e-04]]),
    ],
)
def test_heat_flux_wrapper_calc_heat_flux(heat_flux, expected, model, atoms):
    metadata = ModelMetadata()
    wrapper = heat_flux(model.eval())
    cap = wrapper._model.capabilities()
    outputs = cap.outputs.copy()
    outputs["extra::heat_flux"] = ModelOutput(
        quantity="heat_flux",
        unit="",
        explicit_gradients=[],
        per_atom=False,
    )

    new_cap = ModelCapabilities(
        outputs=outputs,
        atomic_types=cap.atomic_types,
        interaction_range=cap.interaction_range,
        length_unit=cap.length_unit,
        supported_devices=cap.supported_devices,
        dtype=cap.dtype,
    )
    heat_model = AtomisticModel(wrapper.eval(), metadata, capabilities=new_cap).to(
        device="cpu"
    )
    calc = MetatomicCalculator(
        heat_model,
        device="cpu",
        additional_outputs={
            "extra::heat_flux": ModelOutput(
                quantity="heat_flux",
                unit="",
                explicit_gradients=[],
                per_atom=False,
            )
        },
    )
    atoms.calc = calc
    atoms.get_potential_energy()
    assert "extra::heat_flux" in atoms.calc.additional_outputs
    results = atoms.calc.additional_outputs["extra::heat_flux"].block().values
    assert torch.allclose(
        results,
        torch.tensor(expected, dtype=results.dtype),
    )