Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 40 additions & 20 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,30 +65,49 @@ jobs:
- { python: '3.12', resolution: lowest-direct }
- { python: '3.13', resolution: highest }
model:
- { name: fairchem, test_path: "tests/models/test_fairchem.py" }
- { name: fairchem-legacy, test_path: "tests/models/test_fairchem_legacy.py" }
- { name: graphpes, test_path: "tests/models/test_graphpes.py" }
- { name: mace, test_path: "tests/models/test_mace.py" }
- { name: mace, test_path: "tests/test_elastic.py" }
- { name: mace, test_path: "tests/test_optimizers_vs_ase.py" }
- { name: mattersim, test_path: "tests/models/test_mattersim.py" }
- { name: metatomic, test_path: "tests/models/test_metatomic.py" }
- { name: nequip, test_path: "tests/models/test_nequip_framework.py" }
- { name: orb, test_path: "tests/models/test_orb.py" }
- { name: sevenn, test_path: "tests/models/test_sevennet.py" }
- { name: fairchem, test_path: 'tests/models/test_fairchem.py' }
- {
name: fairchem-legacy,
test_path: 'tests/models/test_fairchem_legacy.py',
}
- { name: graphpes, test_path: 'tests/models/test_graphpes.py' }
- { name: mace, test_path: 'tests/models/test_mace.py' }
- { name: mace, test_path: 'tests/test_elastic.py' }
- { name: mace, test_path: 'tests/test_optimizers_vs_ase.py' }
- { name: mattersim, test_path: 'tests/models/test_mattersim.py' }
- { name: metatomic, test_path: 'tests/models/test_metatomic.py' }
- { name: nequip, test_path: 'tests/models/test_nequip_framework.py' }
- { name: orb, test_path: 'tests/models/test_orb.py' }
- { name: sevenn, test_path: 'tests/models/test_sevennet.py' }
exclude:
- version: { python: '3.13', resolution: lowest-direct }
model: { name: orb, test_path: "tests/models/test_orb.py" }
model: { name: orb, test_path: 'tests/models/test_orb.py' }
- version: { python: '3.13', resolution: highest }
model: { name: orb, test_path: "tests/models/test_orb.py" }
model: { name: orb, test_path: 'tests/models/test_orb.py' }
- version: { python: '3.13', resolution: lowest-direct }
model: { name: fairchem-legacy, test_path: "tests/models/test_fairchem_legacy.py" }
model:
{
name: fairchem-legacy,
test_path: 'tests/models/test_fairchem_legacy.py',
}
- version: { python: '3.13', resolution: highest }
model: { name: fairchem-legacy, test_path: "tests/models/test_fairchem_legacy.py" }
model:
{
name: fairchem-legacy,
test_path: 'tests/models/test_fairchem_legacy.py',
}
- version: { python: '3.13', resolution: lowest-direct }
model: { name: nequip, test_path: "tests/models/test_nequip_framework.py" }
model:
{
name: nequip,
test_path: 'tests/models/test_nequip_framework.py',
}
- version: { python: '3.13', resolution: highest }
model: { name: nequip, test_path: "tests/models/test_nequip_framework.py" }
model:
{
name: nequip,
test_path: 'tests/models/test_nequip_framework.py',
}
runs-on: ${{ matrix.os }}

steps:
Expand Down Expand Up @@ -129,15 +148,16 @@ jobs:
uv pip install "h5py>=3.12.1" "numpy>=1.26,<3" "scipy<1.17.0" "tables>=3.10.2" "torch>=2" "tqdm>=4.67" --system
uv pip install "ase>=3.26" "phonopy>=2.37.0" "psutil>=7.0.0" "pymatgen>=2025.6.14" "pytest-cov>=6" "pytest>=8" --resolution=${{ matrix.version.resolution }} --system


- name: Install torch_sim with model dependencies
if: ${{ matrix.model.name != 'fairchem-legacy' }}
run: |
# setuptools <82 provides pkg_resources needed by mattersim and fairchem (via torchtnt).
# setuptools 82+ removed pkg_resources. Remove pin once those packages migrate.
# always use numpy>=2 with Python 3.13
if [ "${{ matrix.version.python }}" = "3.13" ]; then
uv pip install -e ".[test,${{ matrix.model.name }}]" "numpy>=2" --resolution=${{ matrix.version.resolution }} --system
uv pip install -e ".[test,${{ matrix.model.name }}]" "numpy>=2" "setuptools>=70,<82" --resolution=${{ matrix.version.resolution }} --system
else
uv pip install -e ".[test,${{ matrix.model.name }}]" --resolution=${{ matrix.version.resolution }} --system
uv pip install -e ".[test,${{ matrix.model.name }}]" "setuptools>=70,<82" --resolution=${{ matrix.version.resolution }} --system
fi

- name: Run ${{ matrix.model.test_path }} tests
Expand Down
71 changes: 66 additions & 5 deletions tests/test_fix_symmetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,21 +268,40 @@ def test_cubic_forces_vanish(self) -> None:
constraint.adjust_forces(state, forces)
assert torch.allclose(forces[0], torch.zeros(3, dtype=DTYPE), atol=1e-10)

def test_large_deformation_raises(self) -> None:
"""Deformation gradient > 0.25 raises RuntimeError."""
def test_large_deformation_clamped(self) -> None:
"""Per-step deformation > 0.25 is clamped rather than rejected."""
state = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE)
constraint = FixSymmetry.from_state(state, symprec=SYMPREC)
orig_cell = state.cell.clone()
new_cell = state.cell.clone() * 1.5 # 50% strain, well over 0.25
constraint.adjust_cell(state, new_cell)
# Cell should have changed (not rejected) but less than requested
assert not torch.allclose(new_cell, orig_cell * 1.5, atol=1e-6)
# Per-step clamp limits single-step strain to 0.25
identity = torch.eye(3, dtype=DTYPE)
ref_cell = constraint.reference_cells[0]
strain = torch.linalg.solve(ref_cell, new_cell[0].mT) - identity
assert torch.abs(strain).max().item() <= 0.25 + 1e-6

def test_nan_deformation_raises(self) -> None:
"""NaN in proposed cell raises RuntimeError instead of propagating."""
state = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE)
constraint = FixSymmetry.from_state(state, symprec=SYMPREC)
new_cell = state.cell.clone()
new_cell[0] *= 1.5
with pytest.raises(RuntimeError, match="deformation gradient"):
new_cell[0, 0, 0] = float("nan")
with pytest.raises(RuntimeError, match="singular or ill-conditioned"):
constraint.adjust_cell(state, new_cell)

def test_init_mismatched_lengths_raises(self) -> None:
"""Mismatched rotations/symm_maps lengths raises ValueError."""
"""Mismatched rotations/symm_maps/reference_cells lengths raise ValueError."""
rots = [torch.eye(3).unsqueeze(0)]
smaps = [torch.zeros(1, 1, dtype=torch.long), torch.zeros(1, 2, dtype=torch.long)]
with pytest.raises(ValueError, match="length mismatch"):
FixSymmetry(rots, smaps)
# reference_cells length must match n_systems
smaps_ok = [torch.zeros(1, 1, dtype=torch.long)]
with pytest.raises(ValueError, match="reference_cells length"):
FixSymmetry(rots, smaps_ok, reference_cells=[torch.eye(3), torch.eye(3)])

@pytest.mark.parametrize("method", ["adjust_positions", "adjust_cell"])
def test_adjust_skipped_when_disabled(self, method: str) -> None:
Expand Down Expand Up @@ -640,3 +659,45 @@ def test_noisy_model_preserves_symmetry_with_constraint(
)
assert result["initial_spacegroups"][0] == 229
assert result["final_spacegroups"][0] == 229

def test_cumulative_strain_clamp_direct(self) -> None:
"""adjust_cell clamps deformation when cumulative strain exceeds limit.

Directly tests the clamping mechanism by repeatedly applying small
cell deformations that individually pass the per-step check (< 0.25)
but cumulatively exceed max_cumulative_strain. Verifies:
1. The cell doesn't drift beyond the strain envelope
2. Symmetry is preserved after many small steps
"""
state = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE)
constraint = FixSymmetry.from_state(state, symprec=SYMPREC)
constraint.max_cumulative_strain = 0.15
assert constraint.reference_cells is not None
ref_cell = constraint.reference_cells[0].clone()

# Apply 20 small deformations (each ~5% along one axis)
# Total would be ~100% without clamping, well over the 0.15 limit
identity = torch.eye(3, dtype=DTYPE)
for _ in range(20):
# Stretch c-axis by 5% (cubic symmetrization isotropizes this)
stretch = identity.clone()
stretch[2, 2] = 1.05
new_cell = (state.row_vector_cell[0] @ stretch).mT.unsqueeze(0)
constraint.adjust_cell(state, new_cell)
state.cell = new_cell

# Cumulative strain must be clamped to the limit
final_cell = state.row_vector_cell[0]
cumulative = torch.linalg.solve(ref_cell, final_cell) - identity
max_strain = torch.abs(cumulative).max().item()
assert max_strain <= constraint.max_cumulative_strain + 1e-6, (
f"Strain {max_strain:.4f} exceeded {constraint.max_cumulative_strain}"
)

# Without clamping, 1.05^20 = 2.65x → strain ~1.65, far over 0.15
# Verify it's actually being clamped (not just small steps)
assert max_strain > 0.10, f"Strain {max_strain:.4f} suspiciously low"

# Symmetry should still be detectable
datasets = get_symmetry_datasets(state, symprec=SYMPREC)
assert datasets[0].number == SPACEGROUPS["fcc"]
Loading
Loading