From 87760098adc0c9e37c479bf7d525dc994430afde Mon Sep 17 00:00:00 2001 From: gshs12051 Date: Mon, 23 Mar 2026 19:05:08 +0900 Subject: [PATCH 1/3] Integrate atomic virial into ForceStressOutputFromEdge and add atomic virial unit regression tests --- sevenn/calculator.py | 25 ++++++++++- sevenn/main/sevenn_get_model.py | 3 +- sevenn/nn/force_output.py | 60 +++---------------------- sevenn/scripts/deploy.py | 12 +++-- tests/unit_tests/test_atomic_virial.py | 61 ++++++++++++++++++++++++++ 5 files changed, 97 insertions(+), 64 deletions(-) create mode 100644 tests/unit_tests/test_atomic_virial.py diff --git a/sevenn/calculator.py b/sevenn/calculator.py index 714f990c..06c52163 100644 --- a/sevenn/calculator.py +++ b/sevenn/calculator.py @@ -40,6 +40,7 @@ def __init__( enable_cueq: Optional[bool] = False, enable_flash: Optional[bool] = False, enable_oeq: Optional[bool] = False, + atomic_virial: bool = False, sevennet_config: Optional[Dict] = None, # Not used in logic, just meta info **kwargs, ) -> None: @@ -69,9 +70,13 @@ def __init__( if True, use OpenEquivariance to accelerate inference. sevennet_config: dict | None, default=None Not used, but can be used to carry meta information of this calculator + atomic_virial: bool, default=False + If True, request per-atom virial output (`stresses`) at runtime. 
""" super().__init__(**kwargs) self.sevennet_config = None + self.atomic_virial_requested = atomic_virial + self.atomic_virial_from_deploy = False if isinstance(model, pathlib.PurePath): model = str(model) @@ -131,6 +136,7 @@ def __init__( 'version': b'', 'dtype': b'', 'time': b'', + 'atomic_virial': b'', } model_loaded = torch.jit.load( model, _extra_files=extra_dict, map_location=self.device @@ -141,6 +147,9 @@ def __init__( sym_to_num[sym]: i for i, sym in enumerate(chem_symbols.split()) } self.cutoff = float(extra_dict['cutoff'].decode('utf-8')) + self.atomic_virial_from_deploy = ( + extra_dict['atomic_virial'].decode('utf-8') == 'yes' + ) elif isinstance(model, AtomGraphSequential): if model.type_map is None: @@ -161,6 +170,12 @@ def __init__( self.sevennet_config = sevennet_config self.model = model_loaded + if isinstance(self.model, AtomGraphSequential): + force_output = self.model._modules.get('force_output') + if force_output is not None: + self.atomic_virial_from_deploy = self.atomic_virial_from_deploy or bool( + getattr(force_output, 'use_atomic_virial', False) + ) self.modal = None if isinstance(self.model, AtomGraphSequential): @@ -182,6 +197,7 @@ def __init__( 'energy', 'forces', 'stress', + 'stresses', 'energies', ] @@ -218,7 +234,7 @@ def output_to_results(self, output: Dict[str, torch.Tensor]) -> Dict[str, Any]: virial = ( output[KEY.PRED_ATOMIC_VIRIAL].detach().cpu().numpy()[:num_atoms, :] ) - results[KEY.PRED_ATOMIC_VIRIAL] = virial + results['stresses'] = virial return results def calculate(self, atoms=None, properties=None, system_changes=all_changes): @@ -246,8 +262,13 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes): data[KEY.EDGE_VEC].requires_grad_(True) # backward compatibility data = data.to_dict() del data['data_info'] + elif self.atomic_virial_requested: + force_output = self.model._modules.get('force_output') + if force_output is not None and hasattr(force_output, 'use_atomic_virial'): + 
setattr(force_output, 'use_atomic_virial', True) - self.results = self.output_to_results(self.model(data)) + output = self.model(data) + self.results = self.output_to_results(output) class SevenNetD3Calculator(SumCalculator): diff --git a/sevenn/main/sevenn_get_model.py b/sevenn/main/sevenn_get_model.py index e1e87018..24e5bdf8 100644 --- a/sevenn/main/sevenn_get_model.py +++ b/sevenn/main/sevenn_get_model.py @@ -68,7 +68,8 @@ def add_args(parser): '--atomic_virial', help=( 'Serial deploy only: append per-atom virial output ' - '(inferred_atomic_virial) to TorchScript.' + '(inferred_atomic_virial) to TorchScript. This marks model-side ' + 'atomic virial capability for downstream calculator usage.' ), action='store_true', ) diff --git a/sevenn/nn/force_output.py b/sevenn/nn/force_output.py index b3ad9308..fdbbcbb0 100644 --- a/sevenn/nn/force_output.py +++ b/sevenn/nn/force_output.py @@ -64,7 +64,6 @@ def __init__( data_key_force: str = KEY.PRED_FORCE, data_key_stress: str = KEY.PRED_STRESS, data_key_cell_volume: str = KEY.CELL_VOLUME, - retain_graph_for_second_grad: bool = False, ) -> None: super().__init__() @@ -74,7 +73,6 @@ def __init__( self.key_stress = data_key_stress self.key_cell_volume = data_key_cell_volume self._is_batch_data = True - self._retain_graph_for_second_grad = retain_graph_for_second_grad def get_grad_key(self) -> str: return self.key_pos @@ -91,7 +89,6 @@ def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: grad = torch.autograd.grad( energy, [pos_tensor, data['_strain']], - retain_graph=self._retain_graph_for_second_grad, create_graph=self.training, allow_unused=True, # materialize_grads=True, @@ -152,7 +149,9 @@ def __init__( data_key_energy: str = KEY.PRED_TOTAL_ENERGY, data_key_force: str = KEY.PRED_FORCE, data_key_stress: str = KEY.PRED_STRESS, + data_key_atomic_virial: str = KEY.PRED_ATOMIC_VIRIAL, data_key_cell_volume: str = KEY.CELL_VOLUME, + use_atomic_virial: bool = False, ) -> None: super().__init__() @@ -161,7 +160,9 
@@ def __init__( self.key_energy = data_key_energy self.key_force = data_key_force self.key_stress = data_key_stress + self.key_atomic_virial = data_key_atomic_virial self.key_cell_volume = data_key_cell_volume + self.use_atomic_virial = use_atomic_virial self._is_batch_data = True def get_grad_key(self) -> str: @@ -209,6 +210,8 @@ def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: _s = torch.zeros(tot_num, 6, dtype=fij.dtype, device=fij.device) _edge_dst6 = broadcast(edge_idx[1], _virial, 0) _s.scatter_reduce_(0, _edge_dst6, _virial, reduce='sum') + if self.use_atomic_virial: + data[self.key_atomic_virial] = torch.neg(_s) if self._is_batch_data: batch = data[KEY.BATCH] # for deploy, must be defined first @@ -226,54 +229,3 @@ def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: return data - -@compile_mode('script') -class AtomicVirialOutput(nn.Module): - """Per-atom virial from edge forces (post-processing).""" - - def __init__( - self, - data_key_edge: str = KEY.EDGE_VEC, - data_key_edge_idx: str = KEY.EDGE_IDX, - data_key_energy: str = KEY.PRED_TOTAL_ENERGY, - data_key_virial: str = KEY.PRED_ATOMIC_VIRIAL, - ) -> None: - super().__init__() - self.key_edge = data_key_edge - self.key_edge_idx = data_key_edge_idx - self.key_energy = data_key_energy - self.key_virial = data_key_virial - - def get_grad_key(self) -> str: - return self.key_edge - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - rij = data[self.key_edge] - energy = [(data[self.key_energy]).sum()] - - grad_list = torch.autograd.grad( - energy, - [rij], - retain_graph=True, - create_graph=self.training, - allow_unused=True, - ) - fij_opt = grad_list[0] - assert fij_opt is not None, 'No gradient for edge vectors' - fij = fij_opt - - diag = rij * fij - s12 = (rij[..., 0] * fij[..., 1]).unsqueeze(-1) - s23 = (rij[..., 1] * fij[..., 2]).unsqueeze(-1) - s31 = (rij[..., 2] * fij[..., 0]).unsqueeze(-1) - edge_virial = torch.cat([diag, s12, s23, s31], dim=-1) - - 
tot_num = data[KEY.NODE_FEATURE].shape[0] - atom_virial = torch.zeros( - tot_num, 6, dtype=edge_virial.dtype, device=edge_virial.device - ) - dst = broadcast(data[self.key_edge_idx][1], edge_virial, 0) - atom_virial.scatter_reduce_(0, dst, edge_virial, reduce='sum') - - data[self.key_virial] = -atom_virial - return data diff --git a/sevenn/scripts/deploy.py b/sevenn/scripts/deploy.py index 0b80416a..66415287 100644 --- a/sevenn/scripts/deploy.py +++ b/sevenn/scripts/deploy.py @@ -22,7 +22,7 @@ def deploy( atomic_virial: bool = False, ) -> None: from sevenn.nn.edge_embedding import EdgePreprocess - from sevenn.nn.force_output import AtomicVirialOutput, ForceStressOutput + from sevenn.nn.force_output import ForceStressOutput, ForceStressOutputFromEdge cp = load_checkpoint(checkpoint) @@ -37,9 +37,10 @@ def deploy( ) model.prepand_module('edge_preprocess', EdgePreprocess(True)) - grad_module = ForceStressOutput( - retain_graph_for_second_grad=atomic_virial, - ) + if atomic_virial: + grad_module = ForceStressOutputFromEdge(use_atomic_virial=True) + else: + grad_module = ForceStressOutput() model.replace_module('force_output', grad_module) new_grad_key = grad_module.get_grad_key() model.key_grad = new_grad_key @@ -56,9 +57,6 @@ def deploy( model.set_is_batch_data(False) model.eval() - if atomic_virial: - model.add_module('atomic_virial', AtomicVirialOutput()) - model = e3nn.util.jit.script(model) model = torch.jit.freeze(model) diff --git a/tests/unit_tests/test_atomic_virial.py b/tests/unit_tests/test_atomic_virial.py new file mode 100644 index 00000000..52d37bcf --- /dev/null +++ b/tests/unit_tests/test_atomic_virial.py @@ -0,0 +1,61 @@ +import numpy as np +from ase.build import bulk + +import sevenn._keys as KEY +from sevenn.scripts.deploy import deploy +from sevenn.calculator import SevenNetCalculator +from sevenn.util import pretrained_name_to_path + + +def _get_atoms_pbc(): + atoms = bulk("NaCl", "rocksalt", a=5.63) + atoms.set_cell([[1.0, 2.815, 2.815], [2.815, 
0.0, 2.815], [2.815, 2.815, 0.0]]) + atoms.set_positions([[0.0, 0.0, 0.0], [2.815, 0.0, 0.0]]) + return atoms + + +def test_atomic_virial_is_exposed_in_python_torchscript_path(tmp_path): + model_path = str(tmp_path / "7net_0_atomic_virial.pt") + deploy(pretrained_name_to_path("7net-0_11July2024"), model_path, atomic_virial=True) + + calc = SevenNetCalculator(model_path, file_type="torchscript") + atoms = _get_atoms_pbc() + atoms.calc = calc + _ = atoms.get_potential_energy() + + assert "stresses" in calc.results + atomic_virial = np.asarray(calc.results["stresses"]) + + assert atomic_virial.shape == (len(atoms), 6) + assert np.isfinite(atomic_virial).all() + assert np.any(np.abs(atomic_virial) > 0.0) + # Full 6-component per-atom reference check. + # Sort rows for deterministic comparison when atom-wise ordering changes. + atomic_virial_ref = np.array( + [ + [10.06461800, -0.07430478, -0.07463801, -0.38235345, -0.04390856, -0.47967120], + [13.55999100, 1.41123860, 1.41158500, -1.28466740, -0.17591128, -1.18766780], + ] + ) + atomic_virial_sorted = atomic_virial[np.argsort(atomic_virial[:, 0])] + atomic_virial_ref_sorted = atomic_virial_ref[np.argsort(atomic_virial_ref[:, 0])] + assert np.allclose(atomic_virial_sorted, atomic_virial_ref_sorted, atol=1e-4) + + # Internal model stress ordering is [xx, yy, zz, xy, yz, zx]. + # Calculator exposes ASE stress as -stress_internal with + # component order [xx, yy, zz, yz, zx, xy]. 
+ virial_sum = atomic_virial.sum(axis=0) + virial_sum_ref = np.array([ + 23.62459886, + 1.33693361, + 1.33694608, + -1.66702306, + -0.21981908, + -1.66734152, + ]) + assert np.allclose(virial_sum, virial_sum_ref, atol=1e-4) + + stress_internal_from_virial = virial_sum / atoms.get_volume() + stress_ase_from_virial = -stress_internal_from_virial[[0, 1, 2, 4, 5, 3]] + + assert np.allclose(calc.results["stress"], stress_ase_from_virial, atol=1e-5) From ab29a551c615c4d519497d12b81dcc5ef93d2e5e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Mar 2026 10:20:02 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- sevenn/nn/force_output.py | 1 - sevenn/scripts/deploy.py | 5 ++++- tests/unit_tests/test_atomic_virial.py | 16 ++++++++-------- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/sevenn/nn/force_output.py b/sevenn/nn/force_output.py index fdbbcbb0..0ee2073b 100644 --- a/sevenn/nn/force_output.py +++ b/sevenn/nn/force_output.py @@ -228,4 +228,3 @@ def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: torch.neg(sout) / data[self.key_cell_volume].unsqueeze(-1) return data - diff --git a/sevenn/scripts/deploy.py b/sevenn/scripts/deploy.py index 66415287..6383167f 100644 --- a/sevenn/scripts/deploy.py +++ b/sevenn/scripts/deploy.py @@ -22,7 +22,10 @@ def deploy( atomic_virial: bool = False, ) -> None: from sevenn.nn.edge_embedding import EdgePreprocess - from sevenn.nn.force_output import ForceStressOutput, ForceStressOutputFromEdge + from sevenn.nn.force_output import ( + ForceStressOutput, + ForceStressOutputFromEdge, + ) cp = load_checkpoint(checkpoint) diff --git a/tests/unit_tests/test_atomic_virial.py b/tests/unit_tests/test_atomic_virial.py index 52d37bcf..18d90464 100644 --- a/tests/unit_tests/test_atomic_virial.py +++ b/tests/unit_tests/test_atomic_virial.py @@ -2,29 +2,29 @@ from 
ase.build import bulk import sevenn._keys as KEY -from sevenn.scripts.deploy import deploy from sevenn.calculator import SevenNetCalculator +from sevenn.scripts.deploy import deploy from sevenn.util import pretrained_name_to_path def _get_atoms_pbc(): - atoms = bulk("NaCl", "rocksalt", a=5.63) + atoms = bulk('NaCl', 'rocksalt', a=5.63) atoms.set_cell([[1.0, 2.815, 2.815], [2.815, 0.0, 2.815], [2.815, 2.815, 0.0]]) atoms.set_positions([[0.0, 0.0, 0.0], [2.815, 0.0, 0.0]]) return atoms def test_atomic_virial_is_exposed_in_python_torchscript_path(tmp_path): - model_path = str(tmp_path / "7net_0_atomic_virial.pt") - deploy(pretrained_name_to_path("7net-0_11July2024"), model_path, atomic_virial=True) + model_path = str(tmp_path / '7net_0_atomic_virial.pt') + deploy(pretrained_name_to_path('7net-0_11July2024'), model_path, atomic_virial=True) - calc = SevenNetCalculator(model_path, file_type="torchscript") + calc = SevenNetCalculator(model_path, file_type='torchscript') atoms = _get_atoms_pbc() atoms.calc = calc _ = atoms.get_potential_energy() - assert "stresses" in calc.results - atomic_virial = np.asarray(calc.results["stresses"]) + assert 'stresses' in calc.results + atomic_virial = np.asarray(calc.results['stresses']) assert atomic_virial.shape == (len(atoms), 6) assert np.isfinite(atomic_virial).all() @@ -58,4 +58,4 @@ def test_atomic_virial_is_exposed_in_python_torchscript_path(tmp_path): stress_internal_from_virial = virial_sum / atoms.get_volume() stress_ase_from_virial = -stress_internal_from_virial[[0, 1, 2, 4, 5, 3]] - assert np.allclose(calc.results["stress"], stress_ase_from_virial, atol=1e-5) + assert np.allclose(calc.results['stress'], stress_ase_from_virial, atol=1e-5) From 967b4bda049907b4f704146248c1330af1dda94a Mon Sep 17 00:00:00 2001 From: LEE EUI TAE Date: Wed, 25 Mar 2026 20:24:18 +0900 Subject: [PATCH 3/3] pair_e3gnn.cpp uses pair-wise force --- sevenn/main/sevenn_get_model.py | 19 +- sevenn/pair_e3gnn/pair_e3gnn.cpp | 228 +++++++++++---------- 
sevenn/scripts/deploy.py | 93 +++++++-- tests/lammps_tests/scripts/stress_skel.lmp | 21 ++ tests/lammps_tests/test_lammps.py | 146 ++++++++++++- tests/unit_tests/test_atomic_virial.py | 4 +- tests/unit_tests/test_calculator.py | 4 +- 7 files changed, 376 insertions(+), 139 deletions(-) create mode 100644 tests/lammps_tests/scripts/stress_skel.lmp diff --git a/sevenn/main/sevenn_get_model.py b/sevenn/main/sevenn_get_model.py index 24e5bdf8..ac3193fd 100644 --- a/sevenn/main/sevenn_get_model.py +++ b/sevenn/main/sevenn_get_model.py @@ -114,12 +114,14 @@ def run(args): if use_mliap and get_parallel: raise ValueError('Currently, ML-IAP interface does not tested on parallel.') - if atomic_virial and not get_serial: - raise ValueError('--atomic_virial is only supported for serial deploy.') + if atomic_virial and (use_mliap or get_parallel): + raise ValueError('--atomic_virial is only supported for serial deployment.') # deploy if output_prefix is None: output_prefix = 'deployed_parallel' if not get_serial else 'deployed_serial' + if atomic_virial: + output_prefix = 'deployed_model' if use_mliap: output_prefix += '_mliap' @@ -131,10 +133,10 @@ def run(args): checkpoint_path = sevenn.util.pretrained_name_to_path(checkpoint) if not use_mliap: - from sevenn.scripts.deploy import deploy, deploy_parallel + from sevenn.scripts.deploy import deploy, deploy_parallel, deploy_ts - if get_serial: - deploy( + if atomic_virial: + deploy_ts( checkpoint_path, output_prefix, modal, @@ -142,6 +144,13 @@ def run(args): use_oeq=use_oeq, atomic_virial=atomic_virial, ) + elif get_serial: + deploy( + checkpoint_path, + output_prefix, + modal, + use_flash=use_flash, + use_oeq=use_oeq) else: deploy_parallel(checkpoint_path, output_prefix, modal, use_flash=use_flash, use_oeq=use_oeq) # noqa: E501 else: diff --git a/sevenn/pair_e3gnn/pair_e3gnn.cpp b/sevenn/pair_e3gnn/pair_e3gnn.cpp index 3e99b80b..94942c6a 100644 --- a/sevenn/pair_e3gnn/pair_e3gnn.cpp +++ b/sevenn/pair_e3gnn/pair_e3gnn.cpp @@ 
-67,7 +67,7 @@ PairE3GNN::~PairE3GNN() { memory->destroy(setflag); memory->destroy(cutsq); memory->destroy(map); - memory->destroy(elements); + // memory->destroy(elements); } } @@ -77,42 +77,32 @@ void PairE3GNN::compute(int eflag, int vflag) { This compute function is ispired/modified from stress branch of pair-nequip https://github.com/mir-group/pair_nequip */ - if (eflag || vflag) ev_setup(eflag, vflag); else evflag = vflag_fdotr = 0; - if (vflag_atom) { - error->all(FLERR, "atomic stress is not supported\n"); - } - - int nlocal = list->inum; // same as nlocal - int *ilist = list->ilist; - tagint *tag = atom->tag; - std::unordered_map tag_map; + // if (vflag_atom) { + // error->all(FLERR, "atomic stress is not supported\n"); + // } if (atom->tag_consecutive() == 0) { - for (int ii = 0; ii < nlocal; ii++) { - const int i = ilist[ii]; - int itag = tag[i]; - tag_map[itag] = ii+1; - // printf("MODIFY setting %i => %i \n",itag, tag_map[itag] ); - } - } else { - //Ordered which mappling required - for (int ii = 0; ii < nlocal; ii++) { - const int itag = ilist[ii]+1; - tag_map[itag] = ii+1; - // printf("normal setting %i => %i \n",itag, tag_map[itag] ); - } + error->all(FLERR, "Pair e3gnn requires consecutive atom IDs"); } double **x = atom->x; double **f = atom->f; int *type = atom->type; - long num_atoms[1] = {nlocal}; + int nlocal = list->inum; // same as nlocal, Why? is it different from atom->nlocal? 
+ int *ilist = list->ilist; + int inum = list->inum; - int tag2i[nlocal]; + // tag ignore PBC + tagint *tag = atom->tag; + + std::unordered_map tag_map; + std::vector graph_index_to_i(nlocal); + + long num_atoms[1] = {nlocal}; int *numneigh = list->numneigh; // j loop cond int **firstneigh = list->firstneigh; // j list @@ -125,78 +115,64 @@ void PairE3GNN::compute(int eflag, int vflag) { } const int nedges_upper_bound = bound; - float cell[3][3]; - cell[0][0] = domain->boxhi[0] - domain->boxlo[0]; - cell[0][1] = 0.0; - cell[0][2] = 0.0; - - cell[1][0] = domain->xy; - cell[1][1] = domain->boxhi[1] - domain->boxlo[1]; - cell[1][2] = 0.0; - - cell[2][0] = domain->xz; - cell[2][1] = domain->yz; - cell[2][2] = domain->boxhi[2] - domain->boxlo[2]; - - torch::Tensor inp_cell = torch::from_blob(cell, {3, 3}, FLOAT_TYPE); - torch::Tensor inp_num_atoms = torch::from_blob(num_atoms, {1}, INTEGER_TYPE); - - torch::Tensor inp_node_type = torch::zeros({nlocal}, INTEGER_TYPE); - torch::Tensor inp_pos = torch::zeros({nlocal, 3}); - - torch::Tensor inp_cell_volume = - torch::dot(inp_cell[0], torch::cross(inp_cell[1], inp_cell[2], 0)); - - float pbc_shift_tmp[nedges_upper_bound][3]; - - auto node_type = inp_node_type.accessor(); - auto pos = inp_pos.accessor(); + std::vector node_type; + float edge_vec[nedges_upper_bound][3]; long edge_idx_src[nedges_upper_bound]; long edge_idx_dst[nedges_upper_bound]; - int nedges = 0; - - for (int ii = 0; ii < nlocal; ii++) { + for (int ii = 0; ii < inum; ii++) { + // populate tag_map of local atoms const int i = ilist[ii]; - int itag = tag_map[tag[i]]; - tag2i[itag - 1] = i; + const int itag = tag[i]; const int itype = type[i]; - node_type[itag - 1] = map[itype]; - pos[itag - 1][0] = x[i][0]; - pos[itag - 1][1] = x[i][1]; - pos[itag - 1][2] = x[i][2]; + tag_map[itag] = ii; + graph_index_to_i[ii] = i; + node_type.push_back(map[itype]); } - for (int ii = 0; ii < nlocal; ii++) { + int nedges = 0; + // loop over neighbors, build graph + for (int ii = 
0; ii < inum; ii++) { const int i = ilist[ii]; - int itag = tag_map[tag[i]]; + const int i_graph_idx = ii; const int *jlist = firstneigh[i]; const int jnum = numneigh[i]; for (int jj = 0; jj < jnum; jj++) { - int j = jlist[jj]; // atom over pbc is different atom - int jtag = tag_map[tag[j]]; // atom over pbs is same atom (it starts from 1) + int j = jlist[jj]; + const int jtag = tag[j]; j &= NEIGHMASK; - const int jtype = type[j]; + const auto found = tag_map.find(jtag); + if (found == tag_map.end()) continue; + const int j_graph_idx = found->second; + + // we have to calculate Rij to check cutoff in lammps side const double delij[3] = {x[j][0] - x[i][0], x[j][1] - x[i][1], x[j][2] - x[i][2]}; const double Rij = delij[0] * delij[0] + delij[1] * delij[1] + delij[2] * delij[2]; - if (Rij < cutoff_square) { - edge_idx_src[nedges] = itag - 1; - edge_idx_dst[nedges] = jtag - 1; - - pbc_shift_tmp[nedges][0] = x[j][0] - pos[jtag - 1][0]; - pbc_shift_tmp[nedges][1] = x[j][1] - pos[jtag - 1][1]; - pbc_shift_tmp[nedges][2] = x[j][2] - pos[jtag - 1][2]; + if (Rij < cutoff_square) { + // if given j is not inside cutoff + if (nedges >= nedges_upper_bound) { + error->all(FLERR, "nedges exceeded nedges_upper_bound"); + } + edge_idx_src[nedges] = i_graph_idx; + edge_idx_dst[nedges] = j_graph_idx; + edge_vec[nedges][0] = delij[0]; + edge_vec[nedges][1] = delij[1]; + edge_vec[nedges][2] = delij[2]; nedges++; } } // j loop end } // i loop end + // convert data to Tensor + auto inp_node_type = torch::from_blob(node_type.data(), nlocal, INTEGER_TYPE); + auto inp_num_atoms = torch::from_blob(num_atoms, {1}, INTEGER_TYPE); + auto edge_idx_src_tensor = torch::from_blob(edge_idx_src, {nedges}, INTEGER_TYPE); auto edge_idx_dst_tensor = @@ -204,66 +180,102 @@ void PairE3GNN::compute(int eflag, int vflag) { auto inp_edge_index = torch::stack({edge_idx_src_tensor, edge_idx_dst_tensor}); - // r' = r + {shift_tensor(integer vector of len 3)} @ cell_tensor - // shift_tensor = (cell_tensor)^-1^T @ 
(r' - r) - torch::Tensor cell_inv_tensor = - inp_cell.inverse().transpose(0, 1).unsqueeze(0).to(device); - torch::Tensor pbc_shift_tmp_tensor = - torch::from_blob(pbc_shift_tmp, {nedges, 3}, FLOAT_TYPE) - .view({nedges, 3, 1}) - .to(device); - torch::Tensor inp_cell_shift = - torch::bmm(cell_inv_tensor.expand({nedges, 3, 3}), pbc_shift_tmp_tensor) - .view({nedges, 3}); - - inp_pos.set_requires_grad(true); - - c10::Dict input_dict; + auto inp_edge_vec = torch::from_blob(edge_vec, {nedges, 3}, FLOAT_TYPE); + if (print_info) { + std::cout << " Nlocal: " << nlocal << std::endl; + std::cout << " Nedges: " << nedges << "\n" << std::endl; + } + + auto edge_vec_device = inp_edge_vec.to(device); + edge_vec_device.set_requires_grad(true); + + torch::Dict input_dict; input_dict.insert("x", inp_node_type.to(device)); - input_dict.insert("pos", inp_pos.to(device)); input_dict.insert("edge_index", inp_edge_index.to(device)); + input_dict.insert("edge_vec", edge_vec_device); input_dict.insert("num_atoms", inp_num_atoms.to(device)); - input_dict.insert("cell_lattice_vectors", inp_cell.to(device)); - input_dict.insert("cell_volume", inp_cell_volume.to(device)); - input_dict.insert("pbc_shift", inp_cell_shift); + input_dict.insert("nlocal", inp_num_atoms.to(torch::kCPU)); std::vector input(1, input_dict); auto output = model.forward(input).toGenericDict(); - torch::Tensor total_energy_tensor = - output.at("inferred_total_energy").toTensor().cpu(); - torch::Tensor force_tensor = output.at("inferred_force").toTensor().cpu(); + torch::Tensor energy_tensor = + output.at("inferred_total_energy").toTensor().squeeze(); + + // dE_dr + auto grads = torch::autograd::grad({energy_tensor}, {edge_vec_device}); + torch::Tensor dE_dr = grads[0].to(torch::kCPU); + + eng_vdwl += energy_tensor.detach().to(torch::kCPU).item(); + torch::Tensor force_tensor = torch::zeros({nlocal, 3}); + + auto _edge_idx_src_tensor = + edge_idx_src_tensor.repeat_interleave(3).view({nedges, 3}); + auto 
_edge_idx_dst_tensor = + edge_idx_dst_tensor.repeat_interleave(3).view({nedges, 3}); + + force_tensor.scatter_reduce_(0, _edge_idx_src_tensor, dE_dr, "sum"); + force_tensor.scatter_reduce_(0, _edge_idx_dst_tensor, torch::neg(dE_dr), + "sum"); + auto forces = force_tensor.accessor(); - eng_vdwl += total_energy_tensor.item(); - for (int itag = 0; itag < nlocal; itag++) { - int i = tag2i[itag]; - f[i][0] += forces[itag][0]; - f[i][1] += forces[itag][1]; - f[i][2] += forces[itag][2]; + for (int graph_idx = 0; graph_idx < nlocal; graph_idx++) { + int i = graph_index_to_i[graph_idx]; + f[i][0] += forces[graph_idx][0]; + f[i][1] += forces[graph_idx][1]; + f[i][2] += forces[graph_idx][2]; } + // Virial stress from edge contributions if (vflag) { - // more accurately, it is virial part of stress - torch::Tensor stress_tensor = output.at("inferred_stress").toTensor().cpu(); - auto virial_stress_tensor = stress_tensor * inp_cell_volume; - // xy yz zx order in vasp (voigt is xx yy zz yz xz xy) + auto diag = inp_edge_vec * dE_dr; + auto s12 = inp_edge_vec.select(1, 0) * dE_dr.select(1, 1); + auto s23 = inp_edge_vec.select(1, 1) * dE_dr.select(1, 2); + auto s31 = inp_edge_vec.select(1, 2) * dE_dr.select(1, 0); + std::vector voigt_list = { + diag, s12.unsqueeze(-1), s23.unsqueeze(-1), s31.unsqueeze(-1)}; + auto voigt = torch::cat(voigt_list, 1); + + torch::Tensor per_atom_stress_tensor = torch::zeros({nlocal, 6}); + auto _edge_idx_dst6_tensor = + edge_idx_dst_tensor.repeat_interleave(6).view({nedges, 6}); + per_atom_stress_tensor.scatter_reduce_(0, _edge_idx_dst6_tensor, voigt, + "sum"); + + auto virial_stress_tensor = + torch::neg(torch::sum(per_atom_stress_tensor, 0)); auto virial_stress = virial_stress_tensor.accessor(); + virial[0] += virial_stress[0]; virial[1] += virial_stress[1]; virial[2] += virial_stress[2]; virial[3] += virial_stress[3]; virial[4] += virial_stress[5]; virial[5] += virial_stress[4]; + + if (vflag_atom) { + auto per_atom_stress = 
per_atom_stress_tensor.accessor(); + + for (int gi = 0; gi < nlocal; gi++) { + const int i = graph_index_to_i[gi]; + vatom[i][0] += -per_atom_stress[gi][0]; + vatom[i][1] += -per_atom_stress[gi][1]; + vatom[i][2] += -per_atom_stress[gi][2]; + vatom[i][3] += -per_atom_stress[gi][3]; + vatom[i][4] += -per_atom_stress[gi][5]; + vatom[i][5] += -per_atom_stress[gi][4]; + } + } } if (eflag_atom) { torch::Tensor atomic_energy_tensor = - output.at("atomic_energy").toTensor().cpu().view({nlocal}); + output.at("atomic_energy").toTensor().to(torch::kCPU).view({nlocal}); auto atomic_energy = atomic_energy_tensor.accessor(); - for (int itag = 0; itag < nlocal; itag++) { - int i = tag2i[itag]; - eatom[i] += atomic_energy[itag]; + for (int gi = 0; gi < nlocal; gi++) { + const int i = graph_index_to_i[gi]; + eatom[i] += atomic_energy[gi]; } } diff --git a/sevenn/scripts/deploy.py b/sevenn/scripts/deploy.py index 6383167f..573ccc78 100644 --- a/sevenn/scripts/deploy.py +++ b/sevenn/scripts/deploy.py @@ -21,14 +21,7 @@ def deploy( use_oeq: bool = False, atomic_virial: bool = False, ) -> None: - from sevenn.nn.edge_embedding import EdgePreprocess - from sevenn.nn.force_output import ( - ForceStressOutput, - ForceStressOutputFromEdge, - ) - cp = load_checkpoint(checkpoint) - model, config = ( cp.build_model( enable_cueq=False, @@ -39,14 +32,8 @@ def deploy( cp.config, ) - model.prepand_module('edge_preprocess', EdgePreprocess(True)) - if atomic_virial: - grad_module = ForceStressOutputFromEdge(use_atomic_virial=True) - else: - grad_module = ForceStressOutput() - model.replace_module('force_output', grad_module) - new_grad_key = grad_module.get_grad_key() - model.key_grad = new_grad_key + if 'force_output' in model._modules: + model.delete_module_by_key('force_output') if hasattr(model, 'eval_type_map'): setattr(model, 'eval_type_map', False) @@ -183,3 +170,79 @@ def deploy_parallel( model = torch.jit.freeze(model) torch.jit.save(model, fname_full, _extra_files=md_configs) + + +def 
deploy_ts( + checkpoint: Union[pathlib.Path, str], + fname='deployed_model.pt', + modal: Optional[str] = None, + use_flash: bool = False, + use_oeq: bool = False, + atomic_virial: bool = False, +) -> None: + ''' + only for SevenNetCalculator with torchscript input (not for e3gnn) + ''' + from sevenn.nn.edge_embedding import EdgePreprocess + from sevenn.nn.force_output import ( + ForceStressOutput, + ForceStressOutputFromEdge, + ) + + cp = load_checkpoint(checkpoint) + + model, config = ( + cp.build_model( + enable_cueq=False, + enable_flash=use_flash, + enable_oeq=use_oeq, + _flash_lammps=use_flash, + ), + cp.config, + ) + + model.prepand_module('edge_preprocess', EdgePreprocess(True)) + grad_module = ForceStressOutputFromEdge(use_atomic_virial=True) + model.replace_module('force_output', grad_module) + new_grad_key = grad_module.get_grad_key() + model.key_grad = new_grad_key + + if hasattr(model, 'eval_type_map'): + setattr(model, 'eval_type_map', False) + + if modal: + model.prepare_modal_deploy(modal) + elif model.modal_map is not None and len(model.modal_map) >= 1: + raise ValueError( + f'Modal is not given. 
It has: {list(model.modal_map.keys())}' + ) + + model.set_is_batch_data(False) + model.eval() + + model = e3nn.util.jit.script(model) + model = torch.jit.freeze(model) + + # make some config needed for md + md_configs = {} + type_map = config[KEY.TYPE_MAP] + chem_list = '' + for Z in type_map.keys(): + chem_list += chemical_symbols[Z] + ' ' + chem_list = chem_list.strip() + md_configs.update({'chemical_symbols_to_index': chem_list}) + md_configs.update({'cutoff': str(config[KEY.CUTOFF])}) + md_configs.update({'num_species': str(config[KEY.NUM_SPECIES])}) + md_configs.update({'flashTP': 'yes' if use_flash else 'no'}) + md_configs.update({'oeq': 'yes' if use_oeq else 'no'}) + md_configs.update( + {'model_type': config.pop(KEY.MODEL_TYPE, 'E3_equivariant_model')} + ) + md_configs.update({'version': __version__}) + md_configs.update({'dtype': config.pop(KEY.DTYPE, 'single')}) + md_configs.update({'time': datetime.now().strftime('%Y-%m-%d')}) + md_configs.update({'atomic_virial': 'yes' if atomic_virial else 'no'}) + + if fname.endswith('.pt') is False: + fname += '.pt' + torch.jit.save(model, fname, _extra_files=md_configs) diff --git a/tests/lammps_tests/scripts/stress_skel.lmp b/tests/lammps_tests/scripts/stress_skel.lmp new file mode 100644 index 00000000..10b69e8c --- /dev/null +++ b/tests/lammps_tests/scripts/stress_skel.lmp @@ -0,0 +1,21 @@ + units metal + boundary __BOUNDARY__ + read_data __LMP_STCT__ + + mass * 1.0 # do not matter since we don't run MD + + pair_style __PAIR_STYLE__ + pair_coeff * * __POTENTIALS__ __ELEMENT__ + + timestep 0.002 + + compute pa all pe/atom + compute astress all stress/atom NULL pair + + thermo 1 + fix 1 all nve + thermo_style custom step tpcpu pe ke vol pxx pyy pzz pxy pxz pyz press temp + dump mydump all custom 1 __FORCE_DUMP_PATH__ id type element c_pa c_astress[1] c_astress[2] c_astress[3] c_astress[4] c_astress[5] c_astress[6] x y z fx fy fz + dump_modify mydump sort id element __ELEMENT__ + + run 0 diff --git 
a/tests/lammps_tests/test_lammps.py b/tests/lammps_tests/test_lammps.py index 8847feea..502fd3e1 100644 --- a/tests/lammps_tests/test_lammps.py +++ b/tests/lammps_tests/test_lammps.py @@ -18,7 +18,7 @@ from sevenn.nn.cue_helper import is_cue_available from sevenn.nn.flash_helper import is_flash_available from sevenn.nn.oeq_helper import is_oeq_available -from sevenn.scripts.deploy import deploy, deploy_parallel +from sevenn.scripts.deploy import deploy, deploy_parallel, deploy_ts from sevenn.util import chemical_species_preprocess, pretrained_name_to_path logger = logging.getLogger('test_lammps') @@ -28,6 +28,9 @@ lmp_script_path = str( (pathlib.Path(__file__).parent / 'scripts' / 'skel.lmp').resolve() ) +lmp_stress_script_path = str( + (pathlib.Path(__file__).parent / 'scripts' / 'stress_skel.lmp').resolve() +) data_root = (pathlib.Path(__file__).parent.parent / 'data').resolve() cp_0_path = str(data_root / 'checkpoints' / 'cp_0.pth') # knows Hf, O @@ -106,6 +109,22 @@ def ref_modal_calculator(): return SevenNetCalculator(cp_mf_path, modal='PBE') +@pytest.fixture(scope='module') +def ref_stress_calculator(tmp_path_factory): + tmp = tmp_path_factory.mktemp('serial_potential_atomic_virial') + pot_path = str(tmp / 'deployed_atomic_virial.pt') + deploy_ts(cp_0_path, pot_path, atomic_virial=True) + return SevenNetCalculator(pot_path, file_type='torchscript') + + +@pytest.fixture(scope='module') +def ref_7net0_stress_calculator(tmp_path_factory): + tmp = tmp_path_factory.mktemp('serial_potential_atomic_virial') + pot_path = str(tmp / 'deployed_atomic_virial.pt') + deploy_ts(cp_7net0_path, pot_path, atomic_virial=True) + return SevenNetCalculator(pot_path, file_type='torchscript') + + def get_model_config(): config = { 'cutoff': cutoff, @@ -176,7 +195,7 @@ def get_system(system_name, **kwargs): raise ValueError() -def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6): +def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6, check_atomic_stress=False): def acl(a, b, 
rtol=rtol, atol=atol): return np.allclose(a, b, rtol=rtol, atol=atol) @@ -190,6 +209,11 @@ def acl(a, b, rtol=rtol, atol=atol): rtol * 10, atol * 10, ) + if check_atomic_stress: + ref_atomic_virial = np.asarray(atoms1.calc.results['stresses']) + lmp_atomic_stress = np.asarray(atoms2.calc.results['atomic_stress']) + lmp_atomic_virial = -lmp_atomic_stress[:, [0, 1, 2, 3, 5, 4]] + assert acl(ref_atomic_virial, lmp_atomic_virial, rtol * 10, atol * 10) # assert acl(atoms1.get_potential_energies(), atoms2.get_potential_energies()) @@ -219,7 +243,7 @@ def _lammps_results_to_atoms(lammps_log, force_dump): 'lmp_dump': force_dump, } # atomic energy read - latoms.calc.results['energies'] = latoms.arrays['c_pa'][:, 0] + latoms.calc.results['energies'] = np.ravel(latoms.arrays['c_pa']) stress = np.array( [ [lmp_log['Pxx'], lmp_log['Pxy'], lmp_log['Pxz']], @@ -230,10 +254,41 @@ def _lammps_results_to_atoms(lammps_log, force_dump): stress = -1 * stress / 1602.1766208 / 1000 # convert bars to eV/A^3 latoms.calc.results['stress'] = stress + if 'c_astress[1]' in latoms.arrays: + atomic_stress = np.column_stack( + [ + np.asarray(latoms.arrays['c_astress[1]']), + np.asarray(latoms.arrays['c_astress[2]']), + np.asarray(latoms.arrays['c_astress[3]']), + np.asarray(latoms.arrays['c_astress[4]']), + np.asarray(latoms.arrays['c_astress[5]']), + np.asarray(latoms.arrays['c_astress[6]']), + ] + ) + latoms.calc.results['atomic_stress'] = atomic_stress / 1602.1766208 / 1000 + return latoms -def _run_lammps(atoms, pair_style, potential, wd, command, test_name): +def _run_lammps(atoms, pair_style, potential, wd, command, test_name, script_path): + def _rotate_stress(atomic_stress, rot_mat): + out = np.empty_like(atomic_stress) + for i, s in enumerate(atomic_stress): + sigma = np.array([ + [s[0], s[3], s[4]], + [s[3], s[1], s[5]], + [s[4], s[5], s[2]] + ]) + sigma = rot_mat @ sigma @ rot_mat.T + out[i] = [ + sigma[0, 0], + sigma[1, 1], + sigma[2, 2], + sigma[0, 1], + sigma[0, 2], + sigma[1, 2] + 
] + return out wd = wd.resolve() pbc = atoms.get_pbc() pbc_str = ' '.join(['p' if x else 'f' for x in pbc]) @@ -248,7 +303,7 @@ def _run_lammps(atoms, pair_style, potential, wd, command, test_name): lmp_stct, atoms, prismobj=prism, specorder=chem ) - with open(lmp_script_path, 'r') as f: + with open(script_path, 'r') as f: cont = f.read() lammps_log = str(wd / 'log.lammps') @@ -276,6 +331,10 @@ def _run_lammps(atoms, pair_style, potential, wd, command, test_name): rot_mat = prism.rot_mat results = copy.deepcopy(lmp_atoms.calc.results) + + # SinglePointCalculator does not know atomic_stress + at_stress = results.pop('atomic_stress', None) + r_force = np.dot(results['forces'], rot_mat.T) results['forces'] = r_force if 'stress' in results: @@ -287,19 +346,33 @@ def _run_lammps(atoms, pair_style, potential, wd, command, test_name): lmp_atoms.set_cell(r_cell, scale_atoms=True) lmp_atoms = SinglePointCalculator(lmp_atoms, **results).get_atoms() + if at_stress is not None: + lmp_atoms.calc.results['atomic_stress'] = _rotate_stress(at_stress, rot_mat) + return lmp_atoms def serial_lammps_run(atoms, potential, wd, test_name, lammps_cmd): command = lammps_cmd - return _run_lammps(atoms, 'e3gnn', potential, wd, command, test_name) + return _run_lammps( + atoms, 'e3gnn', potential, wd, command, test_name, lmp_script_path + ) def parallel_lammps_run( atoms, potential, wd, test_name, ncores, lammps_cmd, mpirun_cmd ): command = f'{mpirun_cmd} -np {ncores} {lammps_cmd}' - return _run_lammps(atoms, 'e3gnn/parallel', potential, wd, command, test_name) + return _run_lammps( + atoms, 'e3gnn/parallel', potential, wd, command, test_name, lmp_script_path + ) + + +def serial_stress_lammps_run(atoms, potential, wd, test_name, lammps_cmd): + command = lammps_cmd + return _run_lammps( + atoms, 'e3gnn', potential, wd, command, test_name, lmp_stress_script_path + ) def subprocess_routine(cmd, name): @@ -370,6 +443,65 @@ def test_serial_flash( assert_atoms(atoms, atoms_lammps, atol=1e-5) 
+@pytest.mark.parametrize( + 'system', + ['bulk', 'surface'], +) +def test_serial_stress( + system, serial_potential_path, ref_stress_calculator, lammps_cmd, tmp_path +): + atoms = get_system(system) + atoms_lammps = serial_stress_lammps_run( + atoms=atoms, + potential=serial_potential_path, + wd=tmp_path, + test_name='serial stress lmp test', + lammps_cmd=lammps_cmd, + ) + atoms.calc = ref_stress_calculator + assert_atoms(atoms, atoms_lammps, atol=1e-5, check_atomic_stress=True) + + +@pytest.mark.skipif(not is_oeq_available(), reason='oeq not available') +@pytest.mark.parametrize( + 'system', + ['bulk', 'surface'], +) +def test_serial_stress_oeq( + system, serial_potential_path_oeq, ref_7net0_stress_calculator, lammps_cmd, tmp_path +): + atoms = get_system(system) + atoms_lammps = serial_stress_lammps_run( + atoms=atoms, + potential=serial_potential_path_oeq, + wd=tmp_path, + test_name='serial oeq stress lmp test', + lammps_cmd=lammps_cmd, + ) + atoms.calc = ref_7net0_stress_calculator + assert_atoms(atoms, atoms_lammps, atol=1e-5, check_atomic_stress=True) + + +@pytest.mark.skipif(not is_flash_available(), reason='flash not available') +@pytest.mark.parametrize( + 'system', + ['bulk', 'surface'], +) +def test_serial_stress_flash( + system, serial_potential_path_flash, ref_7net0_stress_calculator, lammps_cmd, tmp_path +): + atoms = get_system(system) + atoms_lammps = serial_stress_lammps_run( + atoms=atoms, + potential=serial_potential_path_flash, + wd=tmp_path, + test_name='serial flash stress lmp test', + lammps_cmd=lammps_cmd, + ) + atoms.calc = ref_7net0_stress_calculator + assert_atoms(atoms, atoms_lammps, atol=1e-5, check_atomic_stress=True) + + @pytest.mark.parametrize( 'system,ncores', [ diff --git a/tests/unit_tests/test_atomic_virial.py b/tests/unit_tests/test_atomic_virial.py index 18d90464..3bc56237 100644 --- a/tests/unit_tests/test_atomic_virial.py +++ b/tests/unit_tests/test_atomic_virial.py @@ -3,7 +3,7 @@ import sevenn._keys as KEY from 
sevenn.calculator import SevenNetCalculator -from sevenn.scripts.deploy import deploy +from sevenn.scripts.deploy import deploy_ts from sevenn.util import pretrained_name_to_path @@ -16,7 +16,7 @@ def _get_atoms_pbc(): def test_atomic_virial_is_exposed_in_python_torchscript_path(tmp_path): model_path = str(tmp_path / '7net_0_atomic_virial.pt') - deploy(pretrained_name_to_path('7net-0_11July2024'), model_path, atomic_virial=True) + deploy_ts(pretrained_name_to_path('7net-0_11July2024'), model_path, atomic_virial=True) calc = SevenNetCalculator(model_path, file_type='torchscript') atoms = _get_atoms_pbc() diff --git a/tests/unit_tests/test_calculator.py b/tests/unit_tests/test_calculator.py index 9b19308d..e4aacf2e 100644 --- a/tests/unit_tests/test_calculator.py +++ b/tests/unit_tests/test_calculator.py @@ -8,7 +8,7 @@ from sevenn.calculator import D3Calculator, SevenNetCalculator from sevenn.nn.cue_helper import is_cue_available from sevenn.nn.flash_helper import is_flash_available -from sevenn.scripts.deploy import deploy +from sevenn.scripts.deploy import deploy_ts from sevenn.util import model_from_checkpoint, pretrained_name_to_path @@ -108,7 +108,7 @@ def test_sevennet_0_cal_mol(atoms_mol, sevennet_0_cal): def test_sevennet_0_cal_deployed_consistency(tmp_path, atoms_pbc): atoms_pbc.rattle(stdev=0.01, seed=42) fname = str(tmp_path / '7net_0.pt') - deploy(pretrained_name_to_path('7net-0_11July2024'), fname) + deploy_ts(pretrained_name_to_path('7net-0_11July2024'), fname) calc_script = SevenNetCalculator(fname, file_type='torchscript') calc_cp = SevenNetCalculator(pretrained_name_to_path('7net-0_11July2024'))