diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9528c6322..2e1859898 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -65,30 +65,49 @@ jobs: - { python: '3.12', resolution: lowest-direct } - { python: '3.13', resolution: highest } model: - - { name: fairchem, test_path: "tests/models/test_fairchem.py" } - - { name: fairchem-legacy, test_path: "tests/models/test_fairchem_legacy.py" } - - { name: graphpes, test_path: "tests/models/test_graphpes.py" } - - { name: mace, test_path: "tests/models/test_mace.py" } - - { name: mace, test_path: "tests/test_elastic.py" } - - { name: mace, test_path: "tests/test_optimizers_vs_ase.py" } - - { name: mattersim, test_path: "tests/models/test_mattersim.py" } - - { name: metatomic, test_path: "tests/models/test_metatomic.py" } - - { name: nequip, test_path: "tests/models/test_nequip_framework.py" } - - { name: orb, test_path: "tests/models/test_orb.py" } - - { name: sevenn, test_path: "tests/models/test_sevennet.py" } + - { name: fairchem, test_path: 'tests/models/test_fairchem.py' } + - { + name: fairchem-legacy, + test_path: 'tests/models/test_fairchem_legacy.py', + } + - { name: graphpes, test_path: 'tests/models/test_graphpes.py' } + - { name: mace, test_path: 'tests/models/test_mace.py' } + - { name: mace, test_path: 'tests/test_elastic.py' } + - { name: mace, test_path: 'tests/test_optimizers_vs_ase.py' } + - { name: mattersim, test_path: 'tests/models/test_mattersim.py' } + - { name: metatomic, test_path: 'tests/models/test_metatomic.py' } + - { name: nequip, test_path: 'tests/models/test_nequip_framework.py' } + - { name: orb, test_path: 'tests/models/test_orb.py' } + - { name: sevenn, test_path: 'tests/models/test_sevennet.py' } exclude: - version: { python: '3.13', resolution: lowest-direct } - model: { name: orb, test_path: "tests/models/test_orb.py" } + model: { name: orb, test_path: 'tests/models/test_orb.py' } - version: { python: '3.13', resolution: highest } - model: { name: orb, test_path: "tests/models/test_orb.py" } + model: { name: orb, test_path: 'tests/models/test_orb.py' } - version: { python: '3.13', resolution: lowest-direct } - model: { name: fairchem-legacy, test_path: "tests/models/test_fairchem_legacy.py" } + model: + { + name: fairchem-legacy, + test_path: 'tests/models/test_fairchem_legacy.py', + } - version: { python: '3.13', resolution: highest } - model: { name: fairchem-legacy, test_path: "tests/models/test_fairchem_legacy.py" } + model: + { + name: fairchem-legacy, + test_path: 'tests/models/test_fairchem_legacy.py', + } - version: { python: '3.13', resolution: lowest-direct } - model: { name: nequip, test_path: "tests/models/test_nequip_framework.py" } + model: + { + name: nequip, + test_path: 'tests/models/test_nequip_framework.py', + } - version: { python: '3.13', resolution: highest } - model: { name: nequip, test_path: "tests/models/test_nequip_framework.py" } + model: + { + name: nequip, + test_path: 'tests/models/test_nequip_framework.py', + } runs-on: ${{ matrix.os }} steps: @@ -129,15 +148,16 @@ jobs: uv pip install "h5py>=3.12.1" "numpy>=1.26,<3" "scipy<1.17.0" "tables>=3.10.2" "torch>=2" "tqdm>=4.67" --system uv pip install "ase>=3.26" "phonopy>=2.37.0" "psutil>=7.0.0" "pymatgen>=2025.6.14" "pytest-cov>=6" "pytest>=8" --resolution=${{ matrix.version.resolution }} --system - - name: Install torch_sim with model dependencies if: ${{ matrix.model.name != 'fairchem-legacy' }} run: | + # setuptools <82 provides pkg_resources needed by mattersim and fairchem (via torchtnt). + # setuptools 82+ removed pkg_resources. Remove pin once those packages migrate. # always use numpy>=2 with Python 3.13 if [ "${{ matrix.version.python }}" = "3.13" ]; then - uv pip install -e ".[test,${{ matrix.model.name }}]" "numpy>=2" --resolution=${{ matrix.version.resolution }} --system + uv pip install -e ".[test,${{ matrix.model.name }}]" "numpy>=2" "setuptools>=70,<82" --resolution=${{ matrix.version.resolution }} --system else - uv pip install -e ".[test,${{ matrix.model.name }}]" --resolution=${{ matrix.version.resolution }} --system + uv pip install -e ".[test,${{ matrix.model.name }}]" "setuptools>=70,<82" --resolution=${{ matrix.version.resolution }} --system fi - name: Run ${{ matrix.model.test_path }} tests diff --git a/tests/test_fix_symmetry.py b/tests/test_fix_symmetry.py index 7ddb8dc56..9ddeceffe 100644 --- a/tests/test_fix_symmetry.py +++ b/tests/test_fix_symmetry.py @@ -268,21 +268,40 @@ def test_cubic_forces_vanish(self) -> None: constraint.adjust_forces(state, forces) assert torch.allclose(forces[0], torch.zeros(3, dtype=DTYPE), atol=1e-10) - def test_large_deformation_raises(self) -> None: - """Deformation gradient > 0.25 raises RuntimeError.""" + def test_large_deformation_clamped(self) -> None: + """Per-step deformation > 0.25 is clamped rather than rejected.""" + state = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE) + constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + orig_cell = state.cell.clone() + new_cell = state.cell.clone() * 1.5 # 50% strain, well over 0.25 + constraint.adjust_cell(state, new_cell) + # Cell should have changed (not rejected) but less than requested + assert not torch.allclose(new_cell, orig_cell * 1.5, atol=1e-6) + # Per-step clamp limits single-step strain to 0.25 + identity = torch.eye(3, dtype=DTYPE) + ref_cell = constraint.reference_cells[0] + strain = torch.linalg.solve(ref_cell, new_cell[0].mT) - identity + assert torch.abs(strain).max().item() <= 0.25 + 1e-6 + + def test_nan_deformation_raises(self) -> None: + """NaN in proposed cell raises RuntimeError instead of propagating.""" state = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) new_cell = state.cell.clone() - new_cell[0] *= 1.5 - with pytest.raises(RuntimeError, match="deformation gradient"): + new_cell[0, 0, 0] = float("nan") + with pytest.raises(RuntimeError, match="singular or ill-conditioned"): constraint.adjust_cell(state, new_cell) def test_init_mismatched_lengths_raises(self) -> None: - """Mismatched rotations/symm_maps lengths raises ValueError.""" + """Mismatched rotations/symm_maps/reference_cells lengths raise ValueError.""" rots = [torch.eye(3).unsqueeze(0)] smaps = [torch.zeros(1, 1, dtype=torch.long), torch.zeros(1, 2, dtype=torch.long)] with pytest.raises(ValueError, match="length mismatch"): FixSymmetry(rots, smaps) + # reference_cells length must match n_systems + smaps_ok = [torch.zeros(1, 1, dtype=torch.long)] + with pytest.raises(ValueError, match="reference_cells length"): + FixSymmetry(rots, smaps_ok, reference_cells=[torch.eye(3), torch.eye(3)]) @pytest.mark.parametrize("method", ["adjust_positions", "adjust_cell"]) def test_adjust_skipped_when_disabled(self, method: str) -> None: @@ -640,3 +659,45 @@ def test_noisy_model_preserves_symmetry_with_constraint( ) assert result["initial_spacegroups"][0] == 229 assert result["final_spacegroups"][0] == 229 + + def test_cumulative_strain_clamp_direct(self) -> None: + """adjust_cell clamps deformation when cumulative strain exceeds limit. + + Directly tests the clamping mechanism by repeatedly applying small + cell deformations that individually pass the per-step check (< 0.25) + but cumulatively exceed max_cumulative_strain. Verifies: + 1. The cell doesn't drift beyond the strain envelope + 2. Symmetry is preserved after many small steps + """ + state = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE) + constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + constraint.max_cumulative_strain = 0.15 + assert constraint.reference_cells is not None + ref_cell = constraint.reference_cells[0].clone() + + # Apply 20 small deformations (each ~5% along one axis) + # Total would be ~100% without clamping, well over the 0.15 limit + identity = torch.eye(3, dtype=DTYPE) + for _ in range(20): + # Stretch c-axis by 5% (cubic symmetrization isotropizes this) + stretch = identity.clone() + stretch[2, 2] = 1.05 + new_cell = (state.row_vector_cell[0] @ stretch).mT.unsqueeze(0) + constraint.adjust_cell(state, new_cell) + state.cell = new_cell + + # Cumulative strain must be clamped to the limit + final_cell = state.row_vector_cell[0] + cumulative = torch.linalg.solve(ref_cell, final_cell) - identity + max_strain = torch.abs(cumulative).max().item() + assert max_strain <= constraint.max_cumulative_strain + 1e-6, ( + f"Strain {max_strain:.4f} exceeded {constraint.max_cumulative_strain}" + ) + + # Without clamping, 1.05^20 = 2.65x → strain ~1.65, far over 0.15 + # Verify it's actually being clamped (not just small steps) + assert max_strain > 0.10, f"Strain {max_strain:.4f} suspiciously low" + + # Symmetry should still be detectable + datasets = get_symmetry_datasets(state, symprec=SYMPREC) + assert datasets[0].number == SPACEGROUPS["fcc"] diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index c7c8d5e50..73b29ad2a 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -10,6 +10,7 @@ from __future__ import annotations +import math import warnings from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Self @@ -697,8 +698,10 @@ class FixSymmetry(SystemConstraint): rotations: list[torch.Tensor] symm_maps: list[torch.Tensor] + reference_cells: list[torch.Tensor] | None do_adjust_positions: bool do_adjust_cell: bool + max_cumulative_strain: float def __init__( self, @@ -708,6 +711,8 @@ def __init__( *, adjust_positions: bool = True, adjust_cell: bool = True, + reference_cells: list[torch.Tensor] | None = None, + max_cumulative_strain: float = 0.5, ) -> None: """Initialize FixSymmetry constraint. @@ -717,6 +722,11 @@ def __init__( system_idx: System indices (defaults to 0..n_systems-1). adjust_positions: Whether to symmetrize position displacements. adjust_cell: Whether to symmetrize cell/stress adjustments. + reference_cells: Initial refined cells (row vectors) per system for + cumulative strain tracking. If None, cumulative check is skipped. + max_cumulative_strain: Maximum allowed cumulative strain from the + reference cell. If exceeded, the cell update is clamped to + keep the structure within this strain envelope. """ n_systems = len(rotations) if len(symm_maps) != n_systems: @@ -731,12 +741,19 @@ def __init__( raise ValueError( f"system_idx length ({len(system_idx)}) != n_systems ({n_systems})" ) + if reference_cells is not None and len(reference_cells) != n_systems: + raise ValueError( + f"reference_cells length ({len(reference_cells)}) " + f"!= n_systems ({n_systems})" + ) super().__init__(system_idx=system_idx) self.rotations = rotations self.symm_maps = symm_maps + self.reference_cells = reference_cells self.do_adjust_positions = adjust_positions self.do_adjust_cell = adjust_cell + self.max_cumulative_strain = max_cumulative_strain @classmethod def from_state( @@ -770,7 +787,7 @@ def from_state( from torch_sim.symmetrize import prep_symmetry, refine_and_prep_symmetry - rotations, symm_maps = [], [] + rotations, symm_maps, reference_cells = [], [], [] cumsum = _cumsum_with_zero(state.n_atoms_per_system) for sys_idx in range(state.n_systems): @@ -793,6 +810,8 @@ def from_state( rotations.append(rots) symm_maps.append(smap) + # Store the refined cell as the reference for cumulative strain tracking + reference_cells.append(state.row_vector_cell[sys_idx].clone()) return cls( rotations, @@ -800,6 +819,7 @@ def from_state( system_idx=torch.arange(state.n_systems, device=state.device), adjust_positions=adjust_positions, adjust_cell=adjust_cell, + reference_cells=reference_cells, ) # === Symmetrization hooks === @@ -834,12 +854,17 @@ def adjust_cell(self, state: SimState, new_cell: torch.Tensor) -> None: Computes ``F = inv(cell) @ new_cell_row``, symmetrizes ``F - I`` as a rank-2 tensor, then reconstructs ``cell @ (sym(F-I) + I)``. + Also checks cumulative strain from the initial reference cell. If the + total deformation exceeds ``max_cumulative_strain``, the update is + clamped to prevent phase transitions that would break the symmetry + constraint (e.g. hexagonal → tetragonal cell collapse). + Args: state: Current simulation state. new_cell: Cell tensor (n_systems, 3, 3) in column vector convention. Raises: - RuntimeError: If deformation gradient > 0.25. + RuntimeError: If deformation gradient contains NaN or Inf. """ if not self.do_adjust_cell: return @@ -850,16 +875,37 @@ def adjust_cell(self, state: SimState, new_cell: torch.Tensor) -> None: for ci, si in enumerate(self.system_idx): cur_cell = state.row_vector_cell[si] new_row = new_cell[si].mT # column → row convention + + # Per-step deformation: clamp large steps to avoid ill-conditioned + # symmetrization while still making progress. The cumulative strain + # guard below is the real safety net against phase transitions. deform_delta = torch.linalg.solve(cur_cell, new_row) - identity max_delta = torch.abs(deform_delta).max().item() - if not (max_delta <= 0.25): # catches NaN via negated comparison + if not math.isfinite(max_delta): raise RuntimeError( - f"FixSymmetry: deformation gradient {max_delta:.4f} > 0.25 " - f"too large. Use smaller optimization steps." + f"FixSymmetry: deformation gradient is {max_delta}, " + f"cell may be singular or ill-conditioned." ) + if max_delta > 0.25: + deform_delta = deform_delta * (0.25 / max_delta) + + # Symmetrize the per-step deformation rots = self.rotations[ci].to(dtype=state.dtype) sym_delta = symmetrize_rank2(cur_cell, deform_delta, rots) - new_cell[si] = (cur_cell @ (sym_delta + identity)).mT + proposed_cell = cur_cell @ (sym_delta + identity) + + # Cumulative strain check against reference cell + if self.reference_cells is not None: + ref_cell = self.reference_cells[ci].to( + device=state.device, dtype=state.dtype + ) + cumulative_strain = torch.linalg.solve(ref_cell, proposed_cell) - identity + max_cumulative = torch.abs(cumulative_strain).max().item() + if max_cumulative > self.max_cumulative_strain: + scale = self.max_cumulative_strain / max_cumulative + proposed_cell = ref_cell @ (cumulative_strain * scale + identity) + + new_cell[si] = proposed_cell.mT # back to column convention def _symmetrize_rank1(self, state: SimState, vectors: torch.Tensor) -> None: """Symmetrize a rank-1 tensor in-place for each constrained system.""" @@ -890,6 +936,8 @@ def reindex(self, atom_offset: int, system_offset: int) -> Self: # noqa: ARG002 self.system_idx + system_offset, adjust_positions=self.do_adjust_positions, adjust_cell=self.do_adjust_cell, + reference_cells=list(self.reference_cells) if self.reference_cells else None, + max_cumulative_strain=self.max_cumulative_strain, ) @classmethod @@ -900,21 +948,28 @@ def merge(cls, constraints: list[Self]) -> Self: if any( c.do_adjust_positions != constraints[0].do_adjust_positions or c.do_adjust_cell != constraints[0].do_adjust_cell + or c.max_cumulative_strain != constraints[0].max_cumulative_strain for c in constraints[1:] ): raise ValueError( "Cannot merge FixSymmetry constraints with different " - "adjust_positions/adjust_cell settings" + "adjust_positions/adjust_cell/max_cumulative_strain settings" ) rotations = [r for c in constraints for r in c.rotations] symm_maps = [s for c in constraints for s in c.symm_maps] system_idx = torch.cat([c.system_idx for c in constraints]) + # Merge reference cells if all constraints have them + ref_cells = None + if all(c.reference_cells is not None for c in constraints): + ref_cells = [rc for c in constraints for rc in c.reference_cells] return cls( rotations, symm_maps, system_idx=system_idx, adjust_positions=constraints[0].do_adjust_positions, adjust_cell=constraints[0].do_adjust_cell, + reference_cells=ref_cells, + max_cumulative_strain=constraints[0].max_cumulative_strain, ) def select_constraint( @@ -928,12 +983,19 @@ def select_constraint( if not mask.any(): return None local_idx = mask.nonzero(as_tuple=False).flatten().tolist() + ref_cells = ( + [self.reference_cells[idx] for idx in local_idx] + if self.reference_cells + else None + ) return type(self)( [self.rotations[idx] for idx in local_idx], [self.symm_maps[idx] for idx in local_idx], _mask_constraint_indices(self.system_idx[mask], system_mask), adjust_positions=self.do_adjust_positions, adjust_cell=self.do_adjust_cell, + reference_cells=ref_cells, + max_cumulative_strain=self.max_cumulative_strain, ) def select_sub_constraint( @@ -945,12 +1007,15 @@ def select_sub_constraint( if sys_idx not in self.system_idx: return None local = (self.system_idx == sys_idx).nonzero(as_tuple=True)[0].item() + ref_cells = [self.reference_cells[local]] if self.reference_cells else None return type(self)( [self.rotations[local]], [self.symm_maps[local]], torch.tensor([0], device=self.system_idx.device), adjust_positions=self.do_adjust_positions, adjust_cell=self.do_adjust_cell, + reference_cells=ref_cells, + max_cumulative_strain=self.max_cumulative_strain, ) def __repr__(self) -> str: