From 87760098adc0c9e37c479bf7d525dc994430afde Mon Sep 17 00:00:00 2001 From: gshs12051 Date: Mon, 23 Mar 2026 19:05:08 +0900 Subject: [PATCH 1/7] Integrate atomic virial into ForceStressOutputFromEdge and add atomic virial unit regression tests --- sevenn/calculator.py | 25 ++++++++++- sevenn/main/sevenn_get_model.py | 3 +- sevenn/nn/force_output.py | 60 +++---------------------- sevenn/scripts/deploy.py | 12 +++-- tests/unit_tests/test_atomic_virial.py | 61 ++++++++++++++++++++++++++ 5 files changed, 97 insertions(+), 64 deletions(-) create mode 100644 tests/unit_tests/test_atomic_virial.py diff --git a/sevenn/calculator.py b/sevenn/calculator.py index 714f990c..06c52163 100644 --- a/sevenn/calculator.py +++ b/sevenn/calculator.py @@ -40,6 +40,7 @@ def __init__( enable_cueq: Optional[bool] = False, enable_flash: Optional[bool] = False, enable_oeq: Optional[bool] = False, + atomic_virial: bool = False, sevennet_config: Optional[Dict] = None, # Not used in logic, just meta info **kwargs, ) -> None: @@ -69,9 +70,13 @@ def __init__( if True, use OpenEquivariance to accelerate inference. sevennet_config: dict | None, default=None Not used, but can be used to carry meta information of this calculator + atomic_virial: bool, default=False + If True, request per-atom virial output (`stresses`) at runtime. 
""" super().__init__(**kwargs) self.sevennet_config = None + self.atomic_virial_requested = atomic_virial + self.atomic_virial_from_deploy = False if isinstance(model, pathlib.PurePath): model = str(model) @@ -131,6 +136,7 @@ def __init__( 'version': b'', 'dtype': b'', 'time': b'', + 'atomic_virial': b'', } model_loaded = torch.jit.load( model, _extra_files=extra_dict, map_location=self.device @@ -141,6 +147,9 @@ def __init__( sym_to_num[sym]: i for i, sym in enumerate(chem_symbols.split()) } self.cutoff = float(extra_dict['cutoff'].decode('utf-8')) + self.atomic_virial_from_deploy = ( + extra_dict['atomic_virial'].decode('utf-8') == 'yes' + ) elif isinstance(model, AtomGraphSequential): if model.type_map is None: @@ -161,6 +170,12 @@ def __init__( self.sevennet_config = sevennet_config self.model = model_loaded + if isinstance(self.model, AtomGraphSequential): + force_output = self.model._modules.get('force_output') + if force_output is not None: + self.atomic_virial_from_deploy = self.atomic_virial_from_deploy or bool( + getattr(force_output, 'use_atomic_virial', False) + ) self.modal = None if isinstance(self.model, AtomGraphSequential): @@ -182,6 +197,7 @@ def __init__( 'energy', 'forces', 'stress', + 'stresses', 'energies', ] @@ -218,7 +234,7 @@ def output_to_results(self, output: Dict[str, torch.Tensor]) -> Dict[str, Any]: virial = ( output[KEY.PRED_ATOMIC_VIRIAL].detach().cpu().numpy()[:num_atoms, :] ) - results[KEY.PRED_ATOMIC_VIRIAL] = virial + results['stresses'] = virial return results def calculate(self, atoms=None, properties=None, system_changes=all_changes): @@ -246,8 +262,13 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes): data[KEY.EDGE_VEC].requires_grad_(True) # backward compatibility data = data.to_dict() del data['data_info'] + elif self.atomic_virial_requested: + force_output = self.model._modules.get('force_output') + if force_output is not None and hasattr(force_output, 'use_atomic_virial'): + 
setattr(force_output, 'use_atomic_virial', True) - self.results = self.output_to_results(self.model(data)) + output = self.model(data) + self.results = self.output_to_results(output) class SevenNetD3Calculator(SumCalculator): diff --git a/sevenn/main/sevenn_get_model.py b/sevenn/main/sevenn_get_model.py index e1e87018..24e5bdf8 100644 --- a/sevenn/main/sevenn_get_model.py +++ b/sevenn/main/sevenn_get_model.py @@ -68,7 +68,8 @@ def add_args(parser): '--atomic_virial', help=( 'Serial deploy only: append per-atom virial output ' - '(inferred_atomic_virial) to TorchScript.' + '(inferred_atomic_virial) to TorchScript. This marks model-side ' + 'atomic virial capability for downstream calculator usage.' ), action='store_true', ) diff --git a/sevenn/nn/force_output.py b/sevenn/nn/force_output.py index b3ad9308..fdbbcbb0 100644 --- a/sevenn/nn/force_output.py +++ b/sevenn/nn/force_output.py @@ -64,7 +64,6 @@ def __init__( data_key_force: str = KEY.PRED_FORCE, data_key_stress: str = KEY.PRED_STRESS, data_key_cell_volume: str = KEY.CELL_VOLUME, - retain_graph_for_second_grad: bool = False, ) -> None: super().__init__() @@ -74,7 +73,6 @@ def __init__( self.key_stress = data_key_stress self.key_cell_volume = data_key_cell_volume self._is_batch_data = True - self._retain_graph_for_second_grad = retain_graph_for_second_grad def get_grad_key(self) -> str: return self.key_pos @@ -91,7 +89,6 @@ def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: grad = torch.autograd.grad( energy, [pos_tensor, data['_strain']], - retain_graph=self._retain_graph_for_second_grad, create_graph=self.training, allow_unused=True, # materialize_grads=True, @@ -152,7 +149,9 @@ def __init__( data_key_energy: str = KEY.PRED_TOTAL_ENERGY, data_key_force: str = KEY.PRED_FORCE, data_key_stress: str = KEY.PRED_STRESS, + data_key_atomic_virial: str = KEY.PRED_ATOMIC_VIRIAL, data_key_cell_volume: str = KEY.CELL_VOLUME, + use_atomic_virial: bool = False, ) -> None: super().__init__() @@ -161,7 +160,9 
@@ def __init__( self.key_energy = data_key_energy self.key_force = data_key_force self.key_stress = data_key_stress + self.key_atomic_virial = data_key_atomic_virial self.key_cell_volume = data_key_cell_volume + self.use_atomic_virial = use_atomic_virial self._is_batch_data = True def get_grad_key(self) -> str: @@ -209,6 +210,8 @@ def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: _s = torch.zeros(tot_num, 6, dtype=fij.dtype, device=fij.device) _edge_dst6 = broadcast(edge_idx[1], _virial, 0) _s.scatter_reduce_(0, _edge_dst6, _virial, reduce='sum') + if self.use_atomic_virial: + data[self.key_atomic_virial] = torch.neg(_s) if self._is_batch_data: batch = data[KEY.BATCH] # for deploy, must be defined first @@ -226,54 +229,3 @@ def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: return data - -@compile_mode('script') -class AtomicVirialOutput(nn.Module): - """Per-atom virial from edge forces (post-processing).""" - - def __init__( - self, - data_key_edge: str = KEY.EDGE_VEC, - data_key_edge_idx: str = KEY.EDGE_IDX, - data_key_energy: str = KEY.PRED_TOTAL_ENERGY, - data_key_virial: str = KEY.PRED_ATOMIC_VIRIAL, - ) -> None: - super().__init__() - self.key_edge = data_key_edge - self.key_edge_idx = data_key_edge_idx - self.key_energy = data_key_energy - self.key_virial = data_key_virial - - def get_grad_key(self) -> str: - return self.key_edge - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - rij = data[self.key_edge] - energy = [(data[self.key_energy]).sum()] - - grad_list = torch.autograd.grad( - energy, - [rij], - retain_graph=True, - create_graph=self.training, - allow_unused=True, - ) - fij_opt = grad_list[0] - assert fij_opt is not None, 'No gradient for edge vectors' - fij = fij_opt - - diag = rij * fij - s12 = (rij[..., 0] * fij[..., 1]).unsqueeze(-1) - s23 = (rij[..., 1] * fij[..., 2]).unsqueeze(-1) - s31 = (rij[..., 2] * fij[..., 0]).unsqueeze(-1) - edge_virial = torch.cat([diag, s12, s23, s31], dim=-1) - - 
tot_num = data[KEY.NODE_FEATURE].shape[0] - atom_virial = torch.zeros( - tot_num, 6, dtype=edge_virial.dtype, device=edge_virial.device - ) - dst = broadcast(data[self.key_edge_idx][1], edge_virial, 0) - atom_virial.scatter_reduce_(0, dst, edge_virial, reduce='sum') - - data[self.key_virial] = -atom_virial - return data diff --git a/sevenn/scripts/deploy.py b/sevenn/scripts/deploy.py index 0b80416a..66415287 100644 --- a/sevenn/scripts/deploy.py +++ b/sevenn/scripts/deploy.py @@ -22,7 +22,7 @@ def deploy( atomic_virial: bool = False, ) -> None: from sevenn.nn.edge_embedding import EdgePreprocess - from sevenn.nn.force_output import AtomicVirialOutput, ForceStressOutput + from sevenn.nn.force_output import ForceStressOutput, ForceStressOutputFromEdge cp = load_checkpoint(checkpoint) @@ -37,9 +37,10 @@ def deploy( ) model.prepand_module('edge_preprocess', EdgePreprocess(True)) - grad_module = ForceStressOutput( - retain_graph_for_second_grad=atomic_virial, - ) + if atomic_virial: + grad_module = ForceStressOutputFromEdge(use_atomic_virial=True) + else: + grad_module = ForceStressOutput() model.replace_module('force_output', grad_module) new_grad_key = grad_module.get_grad_key() model.key_grad = new_grad_key @@ -56,9 +57,6 @@ def deploy( model.set_is_batch_data(False) model.eval() - if atomic_virial: - model.add_module('atomic_virial', AtomicVirialOutput()) - model = e3nn.util.jit.script(model) model = torch.jit.freeze(model) diff --git a/tests/unit_tests/test_atomic_virial.py b/tests/unit_tests/test_atomic_virial.py new file mode 100644 index 00000000..52d37bcf --- /dev/null +++ b/tests/unit_tests/test_atomic_virial.py @@ -0,0 +1,61 @@ +import numpy as np +from ase.build import bulk + +import sevenn._keys as KEY +from sevenn.scripts.deploy import deploy +from sevenn.calculator import SevenNetCalculator +from sevenn.util import pretrained_name_to_path + + +def _get_atoms_pbc(): + atoms = bulk("NaCl", "rocksalt", a=5.63) + atoms.set_cell([[1.0, 2.815, 2.815], [2.815, 
0.0, 2.815], [2.815, 2.815, 0.0]]) + atoms.set_positions([[0.0, 0.0, 0.0], [2.815, 0.0, 0.0]]) + return atoms + + +def test_atomic_virial_is_exposed_in_python_torchscript_path(tmp_path): + model_path = str(tmp_path / "7net_0_atomic_virial.pt") + deploy(pretrained_name_to_path("7net-0_11July2024"), model_path, atomic_virial=True) + + calc = SevenNetCalculator(model_path, file_type="torchscript") + atoms = _get_atoms_pbc() + atoms.calc = calc + _ = atoms.get_potential_energy() + + assert "stresses" in calc.results + atomic_virial = np.asarray(calc.results["stresses"]) + + assert atomic_virial.shape == (len(atoms), 6) + assert np.isfinite(atomic_virial).all() + assert np.any(np.abs(atomic_virial) > 0.0) + # Full 6-component per-atom reference check. + # Sort rows for deterministic comparison when atom-wise ordering changes. + atomic_virial_ref = np.array( + [ + [10.06461800, -0.07430478, -0.07463801, -0.38235345, -0.04390856, -0.47967120], + [13.55999100, 1.41123860, 1.41158500, -1.28466740, -0.17591128, -1.18766780], + ] + ) + atomic_virial_sorted = atomic_virial[np.argsort(atomic_virial[:, 0])] + atomic_virial_ref_sorted = atomic_virial_ref[np.argsort(atomic_virial_ref[:, 0])] + assert np.allclose(atomic_virial_sorted, atomic_virial_ref_sorted, atol=1e-4) + + # Internal model stress ordering is [xx, yy, zz, xy, yz, zx]. + # Calculator exposes ASE stress as -stress_internal with + # component order [xx, yy, zz, yz, zx, xy]. 
+ virial_sum = atomic_virial.sum(axis=0) + virial_sum_ref = np.array([ + 23.62459886, + 1.33693361, + 1.33694608, + -1.66702306, + -0.21981908, + -1.66734152, + ]) + assert np.allclose(virial_sum, virial_sum_ref, atol=1e-4) + + stress_internal_from_virial = virial_sum / atoms.get_volume() + stress_ase_from_virial = -stress_internal_from_virial[[0, 1, 2, 4, 5, 3]] + + assert np.allclose(calc.results["stress"], stress_ase_from_virial, atol=1e-5) From ab29a551c615c4d519497d12b81dcc5ef93d2e5e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Mar 2026 10:20:02 +0000 Subject: [PATCH 2/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- sevenn/nn/force_output.py | 1 - sevenn/scripts/deploy.py | 5 ++++- tests/unit_tests/test_atomic_virial.py | 16 ++++++++-------- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/sevenn/nn/force_output.py b/sevenn/nn/force_output.py index fdbbcbb0..0ee2073b 100644 --- a/sevenn/nn/force_output.py +++ b/sevenn/nn/force_output.py @@ -228,4 +228,3 @@ def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: torch.neg(sout) / data[self.key_cell_volume].unsqueeze(-1) return data - diff --git a/sevenn/scripts/deploy.py b/sevenn/scripts/deploy.py index 66415287..6383167f 100644 --- a/sevenn/scripts/deploy.py +++ b/sevenn/scripts/deploy.py @@ -22,7 +22,10 @@ def deploy( atomic_virial: bool = False, ) -> None: from sevenn.nn.edge_embedding import EdgePreprocess - from sevenn.nn.force_output import ForceStressOutput, ForceStressOutputFromEdge + from sevenn.nn.force_output import ( + ForceStressOutput, + ForceStressOutputFromEdge, + ) cp = load_checkpoint(checkpoint) diff --git a/tests/unit_tests/test_atomic_virial.py b/tests/unit_tests/test_atomic_virial.py index 52d37bcf..18d90464 100644 --- a/tests/unit_tests/test_atomic_virial.py +++ b/tests/unit_tests/test_atomic_virial.py @@ -2,29 +2,29 @@ from 
ase.build import bulk import sevenn._keys as KEY -from sevenn.scripts.deploy import deploy from sevenn.calculator import SevenNetCalculator +from sevenn.scripts.deploy import deploy from sevenn.util import pretrained_name_to_path def _get_atoms_pbc(): - atoms = bulk("NaCl", "rocksalt", a=5.63) + atoms = bulk('NaCl', 'rocksalt', a=5.63) atoms.set_cell([[1.0, 2.815, 2.815], [2.815, 0.0, 2.815], [2.815, 2.815, 0.0]]) atoms.set_positions([[0.0, 0.0, 0.0], [2.815, 0.0, 0.0]]) return atoms def test_atomic_virial_is_exposed_in_python_torchscript_path(tmp_path): - model_path = str(tmp_path / "7net_0_atomic_virial.pt") - deploy(pretrained_name_to_path("7net-0_11July2024"), model_path, atomic_virial=True) + model_path = str(tmp_path / '7net_0_atomic_virial.pt') + deploy(pretrained_name_to_path('7net-0_11July2024'), model_path, atomic_virial=True) - calc = SevenNetCalculator(model_path, file_type="torchscript") + calc = SevenNetCalculator(model_path, file_type='torchscript') atoms = _get_atoms_pbc() atoms.calc = calc _ = atoms.get_potential_energy() - assert "stresses" in calc.results - atomic_virial = np.asarray(calc.results["stresses"]) + assert 'stresses' in calc.results + atomic_virial = np.asarray(calc.results['stresses']) assert atomic_virial.shape == (len(atoms), 6) assert np.isfinite(atomic_virial).all() @@ -58,4 +58,4 @@ def test_atomic_virial_is_exposed_in_python_torchscript_path(tmp_path): stress_internal_from_virial = virial_sum / atoms.get_volume() stress_ase_from_virial = -stress_internal_from_virial[[0, 1, 2, 4, 5, 3]] - assert np.allclose(calc.results["stress"], stress_ase_from_virial, atol=1e-5) + assert np.allclose(calc.results['stress'], stress_ase_from_virial, atol=1e-5) From ff77fcd1ebaf78d9741df7890d17bc09646df879 Mon Sep 17 00:00:00 2001 From: LEE EUI TAE Date: Wed, 25 Mar 2026 20:24:18 +0900 Subject: [PATCH 3/7] pair_e3gnn.cpp uses pair-wise force --- sevenn/main/sevenn_get_model.py | 19 +- sevenn/pair_e3gnn/pair_e3gnn.cpp | 228 +++++++++++---------- 
sevenn/scripts/deploy.py | 93 +++++++-- tests/lammps_tests/scripts/stress_skel.lmp | 21 ++ tests/lammps_tests/test_lammps.py | 146 ++++++++++++- tests/unit_tests/test_atomic_virial.py | 4 +- tests/unit_tests/test_calculator.py | 4 +- 7 files changed, 376 insertions(+), 139 deletions(-) create mode 100644 tests/lammps_tests/scripts/stress_skel.lmp diff --git a/sevenn/main/sevenn_get_model.py b/sevenn/main/sevenn_get_model.py index 24e5bdf8..ac3193fd 100644 --- a/sevenn/main/sevenn_get_model.py +++ b/sevenn/main/sevenn_get_model.py @@ -114,12 +114,14 @@ def run(args): if use_mliap and get_parallel: raise ValueError('Currently, ML-IAP interface does not tested on parallel.') - if atomic_virial and not get_serial: - raise ValueError('--atomic_virial is only supported for serial deploy.') + if atomic_virial and (use_mliap or get_parallel): + raise ValueError('--atomic_virial is only supported for serial deployment.') # deploy if output_prefix is None: output_prefix = 'deployed_parallel' if not get_serial else 'deployed_serial' + if atomic_virial: + output_prefix = 'deployed_model' if use_mliap: output_prefix += '_mliap' @@ -131,10 +133,10 @@ def run(args): checkpoint_path = sevenn.util.pretrained_name_to_path(checkpoint) if not use_mliap: - from sevenn.scripts.deploy import deploy, deploy_parallel + from sevenn.scripts.deploy import deploy, deploy_parallel, deploy_ts - if get_serial: - deploy( + if atomic_virial: + deploy_ts( checkpoint_path, output_prefix, modal, @@ -142,6 +144,13 @@ def run(args): use_oeq=use_oeq, atomic_virial=atomic_virial, ) + elif get_serial: + deploy( + checkpoint_path, + output_prefix, + modal, + use_flash=use_flash, + use_oeq=use_oeq) else: deploy_parallel(checkpoint_path, output_prefix, modal, use_flash=use_flash, use_oeq=use_oeq) # noqa: E501 else: diff --git a/sevenn/pair_e3gnn/pair_e3gnn.cpp b/sevenn/pair_e3gnn/pair_e3gnn.cpp index 3e99b80b..94942c6a 100644 --- a/sevenn/pair_e3gnn/pair_e3gnn.cpp +++ b/sevenn/pair_e3gnn/pair_e3gnn.cpp @@ 
-67,7 +67,7 @@ PairE3GNN::~PairE3GNN() { memory->destroy(setflag); memory->destroy(cutsq); memory->destroy(map); - memory->destroy(elements); + // memory->destroy(elements); } } @@ -77,42 +77,32 @@ void PairE3GNN::compute(int eflag, int vflag) { This compute function is ispired/modified from stress branch of pair-nequip https://github.com/mir-group/pair_nequip */ - if (eflag || vflag) ev_setup(eflag, vflag); else evflag = vflag_fdotr = 0; - if (vflag_atom) { - error->all(FLERR, "atomic stress is not supported\n"); - } - - int nlocal = list->inum; // same as nlocal - int *ilist = list->ilist; - tagint *tag = atom->tag; - std::unordered_map tag_map; + // if (vflag_atom) { + // error->all(FLERR, "atomic stress is not supported\n"); + // } if (atom->tag_consecutive() == 0) { - for (int ii = 0; ii < nlocal; ii++) { - const int i = ilist[ii]; - int itag = tag[i]; - tag_map[itag] = ii+1; - // printf("MODIFY setting %i => %i \n",itag, tag_map[itag] ); - } - } else { - //Ordered which mappling required - for (int ii = 0; ii < nlocal; ii++) { - const int itag = ilist[ii]+1; - tag_map[itag] = ii+1; - // printf("normal setting %i => %i \n",itag, tag_map[itag] ); - } + error->all(FLERR, "Pair e3gnn requires consecutive atom IDs"); } double **x = atom->x; double **f = atom->f; int *type = atom->type; - long num_atoms[1] = {nlocal}; + int nlocal = list->inum; // same as nlocal, Why? is it different from atom->nlocal? 
+ int *ilist = list->ilist; + int inum = list->inum; - int tag2i[nlocal]; + // tag ignore PBC + tagint *tag = atom->tag; + + std::unordered_map tag_map; + std::vector graph_index_to_i(nlocal); + + long num_atoms[1] = {nlocal}; int *numneigh = list->numneigh; // j loop cond int **firstneigh = list->firstneigh; // j list @@ -125,78 +115,64 @@ void PairE3GNN::compute(int eflag, int vflag) { } const int nedges_upper_bound = bound; - float cell[3][3]; - cell[0][0] = domain->boxhi[0] - domain->boxlo[0]; - cell[0][1] = 0.0; - cell[0][2] = 0.0; - - cell[1][0] = domain->xy; - cell[1][1] = domain->boxhi[1] - domain->boxlo[1]; - cell[1][2] = 0.0; - - cell[2][0] = domain->xz; - cell[2][1] = domain->yz; - cell[2][2] = domain->boxhi[2] - domain->boxlo[2]; - - torch::Tensor inp_cell = torch::from_blob(cell, {3, 3}, FLOAT_TYPE); - torch::Tensor inp_num_atoms = torch::from_blob(num_atoms, {1}, INTEGER_TYPE); - - torch::Tensor inp_node_type = torch::zeros({nlocal}, INTEGER_TYPE); - torch::Tensor inp_pos = torch::zeros({nlocal, 3}); - - torch::Tensor inp_cell_volume = - torch::dot(inp_cell[0], torch::cross(inp_cell[1], inp_cell[2], 0)); - - float pbc_shift_tmp[nedges_upper_bound][3]; - - auto node_type = inp_node_type.accessor(); - auto pos = inp_pos.accessor(); + std::vector node_type; + float edge_vec[nedges_upper_bound][3]; long edge_idx_src[nedges_upper_bound]; long edge_idx_dst[nedges_upper_bound]; - int nedges = 0; - - for (int ii = 0; ii < nlocal; ii++) { + for (int ii = 0; ii < inum; ii++) { + // populate tag_map of local atoms const int i = ilist[ii]; - int itag = tag_map[tag[i]]; - tag2i[itag - 1] = i; + const int itag = tag[i]; const int itype = type[i]; - node_type[itag - 1] = map[itype]; - pos[itag - 1][0] = x[i][0]; - pos[itag - 1][1] = x[i][1]; - pos[itag - 1][2] = x[i][2]; + tag_map[itag] = ii; + graph_index_to_i[ii] = i; + node_type.push_back(map[itype]); } - for (int ii = 0; ii < nlocal; ii++) { + int nedges = 0; + // loop over neighbors, build graph + for (int ii = 
0; ii < inum; ii++) { const int i = ilist[ii]; - int itag = tag_map[tag[i]]; + const int i_graph_idx = ii; const int *jlist = firstneigh[i]; const int jnum = numneigh[i]; for (int jj = 0; jj < jnum; jj++) { - int j = jlist[jj]; // atom over pbc is different atom - int jtag = tag_map[tag[j]]; // atom over pbs is same atom (it starts from 1) + int j = jlist[jj]; + const int jtag = tag[j]; j &= NEIGHMASK; - const int jtype = type[j]; + const auto found = tag_map.find(jtag); + if (found == tag_map.end()) continue; + const int j_graph_idx = found->second; + + // we have to calculate Rij to check cutoff in lammps side const double delij[3] = {x[j][0] - x[i][0], x[j][1] - x[i][1], x[j][2] - x[i][2]}; const double Rij = delij[0] * delij[0] + delij[1] * delij[1] + delij[2] * delij[2]; - if (Rij < cutoff_square) { - edge_idx_src[nedges] = itag - 1; - edge_idx_dst[nedges] = jtag - 1; - - pbc_shift_tmp[nedges][0] = x[j][0] - pos[jtag - 1][0]; - pbc_shift_tmp[nedges][1] = x[j][1] - pos[jtag - 1][1]; - pbc_shift_tmp[nedges][2] = x[j][2] - pos[jtag - 1][2]; + if (Rij < cutoff_square) { + // if given j is not inside cutoff + if (nedges >= nedges_upper_bound) { + error->all(FLERR, "nedges exceeded nedges_upper_bound"); + } + edge_idx_src[nedges] = i_graph_idx; + edge_idx_dst[nedges] = j_graph_idx; + edge_vec[nedges][0] = delij[0]; + edge_vec[nedges][1] = delij[1]; + edge_vec[nedges][2] = delij[2]; nedges++; } } // j loop end } // i loop end + // convert data to Tensor + auto inp_node_type = torch::from_blob(node_type.data(), nlocal, INTEGER_TYPE); + auto inp_num_atoms = torch::from_blob(num_atoms, {1}, INTEGER_TYPE); + auto edge_idx_src_tensor = torch::from_blob(edge_idx_src, {nedges}, INTEGER_TYPE); auto edge_idx_dst_tensor = @@ -204,66 +180,102 @@ void PairE3GNN::compute(int eflag, int vflag) { auto inp_edge_index = torch::stack({edge_idx_src_tensor, edge_idx_dst_tensor}); - // r' = r + {shift_tensor(integer vector of len 3)} @ cell_tensor - // shift_tensor = (cell_tensor)^-1^T @ 
(r' - r) - torch::Tensor cell_inv_tensor = - inp_cell.inverse().transpose(0, 1).unsqueeze(0).to(device); - torch::Tensor pbc_shift_tmp_tensor = - torch::from_blob(pbc_shift_tmp, {nedges, 3}, FLOAT_TYPE) - .view({nedges, 3, 1}) - .to(device); - torch::Tensor inp_cell_shift = - torch::bmm(cell_inv_tensor.expand({nedges, 3, 3}), pbc_shift_tmp_tensor) - .view({nedges, 3}); - - inp_pos.set_requires_grad(true); - - c10::Dict input_dict; + auto inp_edge_vec = torch::from_blob(edge_vec, {nedges, 3}, FLOAT_TYPE); + if (print_info) { + std::cout << " Nlocal: " << nlocal << std::endl; + std::cout << " Nedges: " << nedges << "\n" << std::endl; + } + + auto edge_vec_device = inp_edge_vec.to(device); + edge_vec_device.set_requires_grad(true); + + torch::Dict input_dict; input_dict.insert("x", inp_node_type.to(device)); - input_dict.insert("pos", inp_pos.to(device)); input_dict.insert("edge_index", inp_edge_index.to(device)); + input_dict.insert("edge_vec", edge_vec_device); input_dict.insert("num_atoms", inp_num_atoms.to(device)); - input_dict.insert("cell_lattice_vectors", inp_cell.to(device)); - input_dict.insert("cell_volume", inp_cell_volume.to(device)); - input_dict.insert("pbc_shift", inp_cell_shift); + input_dict.insert("nlocal", inp_num_atoms.to(torch::kCPU)); std::vector input(1, input_dict); auto output = model.forward(input).toGenericDict(); - torch::Tensor total_energy_tensor = - output.at("inferred_total_energy").toTensor().cpu(); - torch::Tensor force_tensor = output.at("inferred_force").toTensor().cpu(); + torch::Tensor energy_tensor = + output.at("inferred_total_energy").toTensor().squeeze(); + + // dE_dr + auto grads = torch::autograd::grad({energy_tensor}, {edge_vec_device}); + torch::Tensor dE_dr = grads[0].to(torch::kCPU); + + eng_vdwl += energy_tensor.detach().to(torch::kCPU).item(); + torch::Tensor force_tensor = torch::zeros({nlocal, 3}); + + auto _edge_idx_src_tensor = + edge_idx_src_tensor.repeat_interleave(3).view({nedges, 3}); + auto 
_edge_idx_dst_tensor = + edge_idx_dst_tensor.repeat_interleave(3).view({nedges, 3}); + + force_tensor.scatter_reduce_(0, _edge_idx_src_tensor, dE_dr, "sum"); + force_tensor.scatter_reduce_(0, _edge_idx_dst_tensor, torch::neg(dE_dr), + "sum"); + auto forces = force_tensor.accessor(); - eng_vdwl += total_energy_tensor.item(); - for (int itag = 0; itag < nlocal; itag++) { - int i = tag2i[itag]; - f[i][0] += forces[itag][0]; - f[i][1] += forces[itag][1]; - f[i][2] += forces[itag][2]; + for (int graph_idx = 0; graph_idx < nlocal; graph_idx++) { + int i = graph_index_to_i[graph_idx]; + f[i][0] += forces[graph_idx][0]; + f[i][1] += forces[graph_idx][1]; + f[i][2] += forces[graph_idx][2]; } + // Virial stress from edge contributions if (vflag) { - // more accurately, it is virial part of stress - torch::Tensor stress_tensor = output.at("inferred_stress").toTensor().cpu(); - auto virial_stress_tensor = stress_tensor * inp_cell_volume; - // xy yz zx order in vasp (voigt is xx yy zz yz xz xy) + auto diag = inp_edge_vec * dE_dr; + auto s12 = inp_edge_vec.select(1, 0) * dE_dr.select(1, 1); + auto s23 = inp_edge_vec.select(1, 1) * dE_dr.select(1, 2); + auto s31 = inp_edge_vec.select(1, 2) * dE_dr.select(1, 0); + std::vector voigt_list = { + diag, s12.unsqueeze(-1), s23.unsqueeze(-1), s31.unsqueeze(-1)}; + auto voigt = torch::cat(voigt_list, 1); + + torch::Tensor per_atom_stress_tensor = torch::zeros({nlocal, 6}); + auto _edge_idx_dst6_tensor = + edge_idx_dst_tensor.repeat_interleave(6).view({nedges, 6}); + per_atom_stress_tensor.scatter_reduce_(0, _edge_idx_dst6_tensor, voigt, + "sum"); + + auto virial_stress_tensor = + torch::neg(torch::sum(per_atom_stress_tensor, 0)); auto virial_stress = virial_stress_tensor.accessor(); + virial[0] += virial_stress[0]; virial[1] += virial_stress[1]; virial[2] += virial_stress[2]; virial[3] += virial_stress[3]; virial[4] += virial_stress[5]; virial[5] += virial_stress[4]; + + if (vflag_atom) { + auto per_atom_stress = 
per_atom_stress_tensor.accessor(); + + for (int gi = 0; gi < nlocal; gi++) { + const int i = graph_index_to_i[gi]; + vatom[i][0] += -per_atom_stress[gi][0]; + vatom[i][1] += -per_atom_stress[gi][1]; + vatom[i][2] += -per_atom_stress[gi][2]; + vatom[i][3] += -per_atom_stress[gi][3]; + vatom[i][4] += -per_atom_stress[gi][5]; + vatom[i][5] += -per_atom_stress[gi][4]; + } + } } if (eflag_atom) { torch::Tensor atomic_energy_tensor = - output.at("atomic_energy").toTensor().cpu().view({nlocal}); + output.at("atomic_energy").toTensor().to(torch::kCPU).view({nlocal}); auto atomic_energy = atomic_energy_tensor.accessor(); - for (int itag = 0; itag < nlocal; itag++) { - int i = tag2i[itag]; - eatom[i] += atomic_energy[itag]; + for (int gi = 0; gi < nlocal; gi++) { + const int i = graph_index_to_i[gi]; + eatom[i] += atomic_energy[gi]; } } diff --git a/sevenn/scripts/deploy.py b/sevenn/scripts/deploy.py index 6383167f..573ccc78 100644 --- a/sevenn/scripts/deploy.py +++ b/sevenn/scripts/deploy.py @@ -21,14 +21,7 @@ def deploy( use_oeq: bool = False, atomic_virial: bool = False, ) -> None: - from sevenn.nn.edge_embedding import EdgePreprocess - from sevenn.nn.force_output import ( - ForceStressOutput, - ForceStressOutputFromEdge, - ) - cp = load_checkpoint(checkpoint) - model, config = ( cp.build_model( enable_cueq=False, @@ -39,14 +32,8 @@ def deploy( cp.config, ) - model.prepand_module('edge_preprocess', EdgePreprocess(True)) - if atomic_virial: - grad_module = ForceStressOutputFromEdge(use_atomic_virial=True) - else: - grad_module = ForceStressOutput() - model.replace_module('force_output', grad_module) - new_grad_key = grad_module.get_grad_key() - model.key_grad = new_grad_key + if 'force_output' in model._modules: + model.delete_module_by_key('force_output') if hasattr(model, 'eval_type_map'): setattr(model, 'eval_type_map', False) @@ -183,3 +170,79 @@ def deploy_parallel( model = torch.jit.freeze(model) torch.jit.save(model, fname_full, _extra_files=md_configs) + + +def 
deploy_ts( + checkpoint: Union[pathlib.Path, str], + fname='deployed_model.pt', + modal: Optional[str] = None, + use_flash: bool = False, + use_oeq: bool = False, + atomic_virial: bool = False, +) -> None: + ''' + only for SevenNetCalculator with torchscript input (not for e3gnn) + ''' + from sevenn.nn.edge_embedding import EdgePreprocess + from sevenn.nn.force_output import ( + ForceStressOutput, + ForceStressOutputFromEdge, + ) + + cp = load_checkpoint(checkpoint) + + model, config = ( + cp.build_model( + enable_cueq=False, + enable_flash=use_flash, + enable_oeq=use_oeq, + _flash_lammps=use_flash, + ), + cp.config, + ) + + model.prepand_module('edge_preprocess', EdgePreprocess(True)) + grad_module = ForceStressOutputFromEdge(use_atomic_virial=True) + model.replace_module('force_output', grad_module) + new_grad_key = grad_module.get_grad_key() + model.key_grad = new_grad_key + + if hasattr(model, 'eval_type_map'): + setattr(model, 'eval_type_map', False) + + if modal: + model.prepare_modal_deploy(modal) + elif model.modal_map is not None and len(model.modal_map) >= 1: + raise ValueError( + f'Modal is not given. 
It has: {list(model.modal_map.keys())}' + ) + + model.set_is_batch_data(False) + model.eval() + + model = e3nn.util.jit.script(model) + model = torch.jit.freeze(model) + + # make some config need for md + md_configs = {} + type_map = config[KEY.TYPE_MAP] + chem_list = '' + for Z in type_map.keys(): + chem_list += chemical_symbols[Z] + ' ' + chem_list.strip() + md_configs.update({'chemical_symbols_to_index': chem_list}) + md_configs.update({'cutoff': str(config[KEY.CUTOFF])}) + md_configs.update({'num_species': str(config[KEY.NUM_SPECIES])}) + md_configs.update({'flashTP': 'yes' if use_flash else 'no'}) + md_configs.update({'oeq': 'yes' if use_oeq else 'no'}) + md_configs.update( + {'model_type': config.pop(KEY.MODEL_TYPE, 'E3_equivariant_model')} + ) + md_configs.update({'version': __version__}) + md_configs.update({'dtype': config.pop(KEY.DTYPE, 'single')}) + md_configs.update({'time': datetime.now().strftime('%Y-%m-%d')}) + md_configs.update({'atomic_virial': 'yes' if atomic_virial else 'no'}) + + if fname.endswith('.pt') is False: + fname += '.pt' + torch.jit.save(model, fname, _extra_files=md_configs) diff --git a/tests/lammps_tests/scripts/stress_skel.lmp b/tests/lammps_tests/scripts/stress_skel.lmp new file mode 100644 index 00000000..10b69e8c --- /dev/null +++ b/tests/lammps_tests/scripts/stress_skel.lmp @@ -0,0 +1,21 @@ + units metal + boundary __BOUNDARY__ + read_data __LMP_STCT__ + + mass * 1.0 # do not matter since we don't run MD + + pair_style __PAIR_STYLE__ + pair_coeff * * __POTENTIALS__ __ELEMENT__ + + timestep 0.002 + + compute pa all pe/atom + compute astress all stress/atom NULL pair + + thermo 1 + fix 1 all nve + thermo_style custom step tpcpu pe ke vol pxx pyy pzz pxy pxz pyz press temp + dump mydump all custom 1 __FORCE_DUMP_PATH__ id type element c_pa c_astress[1] c_astress[2] c_astress[3] c_astress[4] c_astress[5] c_astress[6] x y z fx fy fz + dump_modify mydump sort id element __ELEMENT__ + + run 0 diff --git 
a/tests/lammps_tests/test_lammps.py b/tests/lammps_tests/test_lammps.py index 8847feea..502fd3e1 100644 --- a/tests/lammps_tests/test_lammps.py +++ b/tests/lammps_tests/test_lammps.py @@ -18,7 +18,7 @@ from sevenn.nn.cue_helper import is_cue_available from sevenn.nn.flash_helper import is_flash_available from sevenn.nn.oeq_helper import is_oeq_available -from sevenn.scripts.deploy import deploy, deploy_parallel +from sevenn.scripts.deploy import deploy, deploy_parallel, deploy_ts from sevenn.util import chemical_species_preprocess, pretrained_name_to_path logger = logging.getLogger('test_lammps') @@ -28,6 +28,9 @@ lmp_script_path = str( (pathlib.Path(__file__).parent / 'scripts' / 'skel.lmp').resolve() ) +lmp_stress_script_path = str( + (pathlib.Path(__file__).parent / 'scripts' / 'stress_skel.lmp').resolve() +) data_root = (pathlib.Path(__file__).parent.parent / 'data').resolve() cp_0_path = str(data_root / 'checkpoints' / 'cp_0.pth') # knows Hf, O @@ -106,6 +109,22 @@ def ref_modal_calculator(): return SevenNetCalculator(cp_mf_path, modal='PBE') +@pytest.fixture(scope='module') +def ref_stress_calculator(tmp_path_factory): + tmp = tmp_path_factory.mktemp('serial_potential_atomic_virial') + pot_path = str(tmp / 'deployed_atomic_virial.pt') + deploy_ts(cp_0_path, pot_path, atomic_virial=True) + return SevenNetCalculator(pot_path, file_type='torchscript') + + +@pytest.fixture(scope='module') +def ref_7net0_stress_calculator(tmp_path_factory): + tmp = tmp_path_factory.mktemp('serial_potential_atomic_virial') + pot_path = str(tmp / 'deployed_atomic_virial.pt') + deploy_ts(cp_7net0_path, pot_path, atomic_virial=True) + return SevenNetCalculator(pot_path, file_type='torchscript') + + def get_model_config(): config = { 'cutoff': cutoff, @@ -176,7 +195,7 @@ def get_system(system_name, **kwargs): raise ValueError() -def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6): +def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6, check_atomic_stress=False): def acl(a, b, 
rtol=rtol, atol=atol): return np.allclose(a, b, rtol=rtol, atol=atol) @@ -190,6 +209,11 @@ def acl(a, b, rtol=rtol, atol=atol): rtol * 10, atol * 10, ) + if check_atomic_stress: + ref_atomic_virial = np.asarray(atoms1.calc.results['stresses']) + lmp_atomic_stress = np.asarray(atoms2.calc.results['atomic_stress']) + lmp_atomic_virial = -lmp_atomic_stress[:, [0, 1, 2, 3, 5, 4]] + assert acl(ref_atomic_virial, lmp_atomic_virial, rtol * 10, atol * 10) # assert acl(atoms1.get_potential_energies(), atoms2.get_potential_energies()) @@ -219,7 +243,7 @@ def _lammps_results_to_atoms(lammps_log, force_dump): 'lmp_dump': force_dump, } # atomic energy read - latoms.calc.results['energies'] = latoms.arrays['c_pa'][:, 0] + latoms.calc.results['energies'] = np.ravel(latoms.arrays['c_pa']) stress = np.array( [ [lmp_log['Pxx'], lmp_log['Pxy'], lmp_log['Pxz']], @@ -230,10 +254,41 @@ def _lammps_results_to_atoms(lammps_log, force_dump): stress = -1 * stress / 1602.1766208 / 1000 # convert bars to eV/A^3 latoms.calc.results['stress'] = stress + if 'c_astress[1]' in latoms.arrays: + atomic_stress = np.column_stack( + [ + np.asarray(latoms.arrays['c_astress[1]']), + np.asarray(latoms.arrays['c_astress[2]']), + np.asarray(latoms.arrays['c_astress[3]']), + np.asarray(latoms.arrays['c_astress[4]']), + np.asarray(latoms.arrays['c_astress[5]']), + np.asarray(latoms.arrays['c_astress[6]']), + ] + ) + latoms.calc.results['atomic_stress'] = atomic_stress / 1602.1766208 / 1000 + return latoms -def _run_lammps(atoms, pair_style, potential, wd, command, test_name): +def _run_lammps(atoms, pair_style, potential, wd, command, test_name, script_path): + def _rotate_stress(atomic_stress, rot_mat): + out = np.empty_like(atomic_stress) + for i, s in enumerate(atomic_stress): + sigma = np.array([ + [s[0], s[3], s[4]], + [s[3], s[1], s[5]], + [s[4], s[5], s[2]] + ]) + sigma = rot_mat @ sigma @ rot_mat.T + out[i] = [ + sigma[0, 0], + sigma[1, 1], + sigma[2, 2], + sigma[0, 1], + sigma[0, 2], + sigma[1, 2] + 
] + return out wd = wd.resolve() pbc = atoms.get_pbc() pbc_str = ' '.join(['p' if x else 'f' for x in pbc]) @@ -248,7 +303,7 @@ def _run_lammps(atoms, pair_style, potential, wd, command, test_name): lmp_stct, atoms, prismobj=prism, specorder=chem ) - with open(lmp_script_path, 'r') as f: + with open(script_path, 'r') as f: cont = f.read() lammps_log = str(wd / 'log.lammps') @@ -276,6 +331,10 @@ def _run_lammps(atoms, pair_style, potential, wd, command, test_name): rot_mat = prism.rot_mat results = copy.deepcopy(lmp_atoms.calc.results) + + # SinglePointCalculator does not know atomic_stress + at_stress = results.pop('atomic_stress', None) + r_force = np.dot(results['forces'], rot_mat.T) results['forces'] = r_force if 'stress' in results: @@ -287,19 +346,33 @@ def _run_lammps(atoms, pair_style, potential, wd, command, test_name): lmp_atoms.set_cell(r_cell, scale_atoms=True) lmp_atoms = SinglePointCalculator(lmp_atoms, **results).get_atoms() + if at_stress is not None: + lmp_atoms.calc.results['atomic_stress'] = _rotate_stress(at_stress, rot_mat) + return lmp_atoms def serial_lammps_run(atoms, potential, wd, test_name, lammps_cmd): command = lammps_cmd - return _run_lammps(atoms, 'e3gnn', potential, wd, command, test_name) + return _run_lammps( + atoms, 'e3gnn', potential, wd, command, test_name, lmp_script_path + ) def parallel_lammps_run( atoms, potential, wd, test_name, ncores, lammps_cmd, mpirun_cmd ): command = f'{mpirun_cmd} -np {ncores} {lammps_cmd}' - return _run_lammps(atoms, 'e3gnn/parallel', potential, wd, command, test_name) + return _run_lammps( + atoms, 'e3gnn/parallel', potential, wd, command, test_name, lmp_script_path + ) + + +def serial_stress_lammps_run(atoms, potential, wd, test_name, lammps_cmd): + command = lammps_cmd + return _run_lammps( + atoms, 'e3gnn', potential, wd, command, test_name, lmp_stress_script_path + ) def subprocess_routine(cmd, name): @@ -370,6 +443,65 @@ def test_serial_flash( assert_atoms(atoms, atoms_lammps, atol=1e-5) 
+@pytest.mark.parametrize( + 'system', + ['bulk', 'surface'], +) +def test_serial_stress( + system, serial_potential_path, ref_stress_calculator, lammps_cmd, tmp_path +): + atoms = get_system(system) + atoms_lammps = serial_stress_lammps_run( + atoms=atoms, + potential=serial_potential_path, + wd=tmp_path, + test_name='serial stress lmp test', + lammps_cmd=lammps_cmd, + ) + atoms.calc = ref_stress_calculator + assert_atoms(atoms, atoms_lammps, atol=1e-5, check_atomic_stress=True) + + +@pytest.mark.skipif(not is_oeq_available(), reason='oeq not available') +@pytest.mark.parametrize( + 'system', + ['bulk', 'surface'], +) +def test_serial_stress_oeq( + system, serial_potential_path_oeq, ref_7net0_stress_calculator, lammps_cmd, tmp_path +): + atoms = get_system(system) + atoms_lammps = serial_stress_lammps_run( + atoms=atoms, + potential=serial_potential_path_oeq, + wd=tmp_path, + test_name='serial oeq stress lmp test', + lammps_cmd=lammps_cmd, + ) + atoms.calc = ref_7net0_stress_calculator + assert_atoms(atoms, atoms_lammps, atol=1e-5, check_atomic_stress=True) + + +@pytest.mark.skipif(not is_flash_available(), reason='flash not available') +@pytest.mark.parametrize( + 'system', + ['bulk', 'surface'], +) +def test_serial_stress_flash( + system, serial_potential_path_flash, ref_7net0_stress_calculator, lammps_cmd, tmp_path +): + atoms = get_system(system) + atoms_lammps = serial_stress_lammps_run( + atoms=atoms, + potential=serial_potential_path_flash, + wd=tmp_path, + test_name='serial flash stress lmp test', + lammps_cmd=lammps_cmd, + ) + atoms.calc = ref_7net0_stress_calculator + assert_atoms(atoms, atoms_lammps, atol=1e-5, check_atomic_stress=True) + + @pytest.mark.parametrize( 'system,ncores', [ diff --git a/tests/unit_tests/test_atomic_virial.py b/tests/unit_tests/test_atomic_virial.py index 18d90464..3bc56237 100644 --- a/tests/unit_tests/test_atomic_virial.py +++ b/tests/unit_tests/test_atomic_virial.py @@ -3,7 +3,7 @@ import sevenn._keys as KEY from 
sevenn.calculator import SevenNetCalculator -from sevenn.scripts.deploy import deploy +from sevenn.scripts.deploy import deploy_ts from sevenn.util import pretrained_name_to_path @@ -16,7 +16,7 @@ def _get_atoms_pbc(): def test_atomic_virial_is_exposed_in_python_torchscript_path(tmp_path): model_path = str(tmp_path / '7net_0_atomic_virial.pt') - deploy(pretrained_name_to_path('7net-0_11July2024'), model_path, atomic_virial=True) + deploy_ts(pretrained_name_to_path('7net-0_11July2024'), model_path, atomic_virial=True) calc = SevenNetCalculator(model_path, file_type='torchscript') atoms = _get_atoms_pbc() diff --git a/tests/unit_tests/test_calculator.py b/tests/unit_tests/test_calculator.py index 9b19308d..e4aacf2e 100644 --- a/tests/unit_tests/test_calculator.py +++ b/tests/unit_tests/test_calculator.py @@ -8,7 +8,7 @@ from sevenn.calculator import D3Calculator, SevenNetCalculator from sevenn.nn.cue_helper import is_cue_available from sevenn.nn.flash_helper import is_flash_available -from sevenn.scripts.deploy import deploy +from sevenn.scripts.deploy import deploy_ts from sevenn.util import model_from_checkpoint, pretrained_name_to_path @@ -108,7 +108,7 @@ def test_sevennet_0_cal_mol(atoms_mol, sevennet_0_cal): def test_sevennet_0_cal_deployed_consistency(tmp_path, atoms_pbc): atoms_pbc.rattle(stdev=0.01, seed=42) fname = str(tmp_path / '7net_0.pt') - deploy(pretrained_name_to_path('7net-0_11July2024'), fname) + deploy_ts(pretrained_name_to_path('7net-0_11July2024'), fname) calc_script = SevenNetCalculator(fname, file_type='torchscript') calc_cp = SevenNetCalculator(pretrained_name_to_path('7net-0_11July2024')) From a397f7363cf6b26a232799c128ea280c6e171554 Mon Sep 17 00:00:00 2001 From: lee2tae Date: Fri, 27 Mar 2026 17:10:28 +0900 Subject: [PATCH 4/7] deploy_ts removed / pytest modified --- sevenn/main/sevenn_get_model.py | 36 +- sevenn/main/sevenn_patch_lammps.py | 15 - sevenn/pair_e3gnn/pair_e3gnn.cpp | 3 - .../pair_e3gnn/pair_e3gnn_atomic_stress.cpp | 444 
--------- .../pair_e3gnn_parallel_atomic_stress.cpp | 913 ------------------ sevenn/pair_e3gnn/patch_lammps.sh | 14 +- sevenn/scripts/deploy.py | 78 -- tests/lammps_tests/test_lammps.py | 40 +- tests/unit_tests/test_atomic_virial.py | 61 -- tests/unit_tests/test_calculator.py | 21 - 10 files changed, 35 insertions(+), 1590 deletions(-) delete mode 100644 sevenn/pair_e3gnn/pair_e3gnn_atomic_stress.cpp delete mode 100644 sevenn/pair_e3gnn/pair_e3gnn_parallel_atomic_stress.cpp delete mode 100644 tests/unit_tests/test_atomic_virial.py diff --git a/sevenn/main/sevenn_get_model.py b/sevenn/main/sevenn_get_model.py index ac3193fd..5dd373ce 100644 --- a/sevenn/main/sevenn_get_model.py +++ b/sevenn/main/sevenn_get_model.py @@ -64,15 +64,6 @@ def add_args(parser): help='Use LAMMPS ML-IAP interface.', action='store_true', ) - ag.add_argument( - '--atomic_virial', - help=( - 'Serial deploy only: append per-atom virial output ' - '(inferred_atomic_virial) to TorchScript. This marks model-side ' - 'atomic virial capability for downstream calculator usage.' 
- ), - action='store_true', - ) def run(args): @@ -114,14 +105,9 @@ def run(args): if use_mliap and get_parallel: raise ValueError('Currently, ML-IAP interface does not tested on parallel.') - if atomic_virial and (use_mliap or get_parallel): - raise ValueError('--atomic_virial is only supported for serial deployment.') - # deploy if output_prefix is None: output_prefix = 'deployed_parallel' if not get_serial else 'deployed_serial' - if atomic_virial: - output_prefix = 'deployed_model' if use_mliap: output_prefix += '_mliap' @@ -133,24 +119,10 @@ def run(args): checkpoint_path = sevenn.util.pretrained_name_to_path(checkpoint) if not use_mliap: - from sevenn.scripts.deploy import deploy, deploy_parallel, deploy_ts - - if atomic_virial: - deploy_ts( - checkpoint_path, - output_prefix, - modal, - use_flash=use_flash, - use_oeq=use_oeq, - atomic_virial=atomic_virial, - ) - elif get_serial: - deploy( - checkpoint_path, - output_prefix, - modal, - use_flash=use_flash, - use_oeq=use_oeq) + from sevenn.scripts.deploy import deploy, deploy_parallel + + if get_serial: + deploy(checkpoint_path, output_prefix, modal, use_flash=use_flash, use_oeq=use_oeq) else: deploy_parallel(checkpoint_path, output_prefix, modal, use_flash=use_flash, use_oeq=use_oeq) # noqa: E501 else: diff --git a/sevenn/main/sevenn_patch_lammps.py b/sevenn/main/sevenn_patch_lammps.py index 3ace2586..e90b1c78 100644 --- a/sevenn/main/sevenn_patch_lammps.py +++ b/sevenn/main/sevenn_patch_lammps.py @@ -36,11 +36,6 @@ def add_args(parser): help='Enable OpenEquivariance', action='store_true', ) - ag.add_argument( - '--atomic_stress', - help='Patch pair_e3gnn with atomic-stress enabled source files.', - action='store_true', - ) # cxx_standard is detected automatically @@ -59,12 +54,6 @@ def run(args): d3_support = '0' print(' - D3 support disabled') - atomic_stress = '1' if args.atomic_stress else '0' - if args.atomic_stress: - print(' - Atomic stress patch enabled') - else: - print(' - Atomic stress patch 
disabled') - so_oeq = '' if args.enable_oeq: try: @@ -122,10 +111,6 @@ def run(args): if args.enable_oeq: assert osp.isfile(so_oeq) cmd += f' {so_oeq}' - else: - cmd += ' NONE' - - cmd += f' {atomic_stress}' res = subprocess.run(cmd.split()) return res.returncode # is it meaningless? diff --git a/sevenn/pair_e3gnn/pair_e3gnn.cpp b/sevenn/pair_e3gnn/pair_e3gnn.cpp index 94942c6a..9d833c83 100644 --- a/sevenn/pair_e3gnn/pair_e3gnn.cpp +++ b/sevenn/pair_e3gnn/pair_e3gnn.cpp @@ -81,9 +81,6 @@ void PairE3GNN::compute(int eflag, int vflag) { ev_setup(eflag, vflag); else evflag = vflag_fdotr = 0; - // if (vflag_atom) { - // error->all(FLERR, "atomic stress is not supported\n"); - // } if (atom->tag_consecutive() == 0) { error->all(FLERR, "Pair e3gnn requires consecutive atom IDs"); diff --git a/sevenn/pair_e3gnn/pair_e3gnn_atomic_stress.cpp b/sevenn/pair_e3gnn/pair_e3gnn_atomic_stress.cpp deleted file mode 100644 index ea161337..00000000 --- a/sevenn/pair_e3gnn/pair_e3gnn_atomic_stress.cpp +++ /dev/null @@ -1,444 +0,0 @@ -/* ---------------------------------------------------------------------- - LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator - https://lammps.sandia.gov/, Sandia National Laboratories - Steve Plimpton, sjplimp@sandia.gov - - Copyright (2003) Sandia Corporation. Under the terms of Contract - DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains - certain rights in this software. This software is distributed under - the GNU General Public License. - - See the README file in the top-level LAMMPS directory. 
-------------------------------------------------------------------------- */ - -/* ---------------------------------------------------------------------- - Contributing author: Yutack Park (SNU) -------------------------------------------------------------------------- */ - -#include -#include -#include -#include - -#include -#include - -#include "atom.h" -#include "domain.h" -#include "error.h" -#include "force.h" -#include "memory.h" -#include "neigh_list.h" -#include "neigh_request.h" -#include "neighbor.h" - -#include "pair_e3gnn.h" - -using namespace LAMMPS_NS; - -// Undefined reference; body in pair_e3gnn_oeq_autograd.cpp to be linked -extern void pair_e3gnn_oeq_register_autograd(); - -#define INTEGER_TYPE torch::TensorOptions().dtype(torch::kInt64) -#define FLOAT_TYPE torch::TensorOptions().dtype(torch::kFloat) - -PairE3GNN::PairE3GNN(LAMMPS *lmp) : Pair(lmp) { - // constructor - // Virial is accumulated directly from the model; skip LAMMPS fdotr post-pass. - no_virial_fdotr_compute = 1; - - const char *print_flag = std::getenv("SEVENN_PRINT_INFO"); - if (print_flag) - print_info = true; - - std::string device_name; - if (torch::cuda::is_available()) { - device = torch::kCUDA; - device_name = "CUDA"; - } else { - device = torch::kCPU; - device_name = "CPU"; - } - - if (lmp->logfile) { - fprintf(lmp->logfile, "PairE3GNN using device : %s\n", device_name.c_str()); - } -} - -PairE3GNN::~PairE3GNN() { - if (allocated) { - memory->destroy(setflag); - memory->destroy(cutsq); - memory->destroy(map); - memory->destroy(elements); - } -} - -void PairE3GNN::compute(int eflag, int vflag) { - // compute - /* - This compute function is ispired/modified from stress branch of pair-nequip - https://github.com/mir-group/pair_nequip - */ - - if (eflag || vflag) - ev_setup(eflag, vflag); - else - evflag = vflag_fdotr = 0; -// if (vflag_atom) { -// error->all(FLERR, "atomic stress is not supported\n"); -// } - - int nlocal = list->inum; // same as nlocal - int *ilist = 
list->ilist; - tagint *tag = atom->tag; - std::unordered_map tag_map; - - if (atom->tag_consecutive() == 0) { - for (int ii = 0; ii < nlocal; ii++) { - const int i = ilist[ii]; - int itag = tag[i]; - tag_map[itag] = ii+1; - // printf("MODIFY setting %i => %i \n",itag, tag_map[itag] ); - } - } else { - //Ordered which mappling required - for (int ii = 0; ii < nlocal; ii++) { - const int itag = ilist[ii]+1; - tag_map[itag] = ii+1; - // printf("normal setting %i => %i \n",itag, tag_map[itag] ); - } - } - - double **x = atom->x; - double **f = atom->f; - int *type = atom->type; - long num_atoms[1] = {nlocal}; - - int tag2i[nlocal]; - - int *numneigh = list->numneigh; // j loop cond - int **firstneigh = list->firstneigh; // j list - - int bound; - if (this->nedges_bound == -1) { - bound = std::accumulate(numneigh, numneigh + nlocal, 0); - } else { - bound = this->nedges_bound; - } - const int nedges_upper_bound = bound; - - float cell[3][3]; - cell[0][0] = domain->boxhi[0] - domain->boxlo[0]; - cell[0][1] = 0.0; - cell[0][2] = 0.0; - - cell[1][0] = domain->xy; - cell[1][1] = domain->boxhi[1] - domain->boxlo[1]; - cell[1][2] = 0.0; - - cell[2][0] = domain->xz; - cell[2][1] = domain->yz; - cell[2][2] = domain->boxhi[2] - domain->boxlo[2]; - - torch::Tensor inp_cell = torch::from_blob(cell, {3, 3}, FLOAT_TYPE); - torch::Tensor inp_num_atoms = torch::from_blob(num_atoms, {1}, INTEGER_TYPE); - - torch::Tensor inp_node_type = torch::zeros({nlocal}, INTEGER_TYPE); - torch::Tensor inp_pos = torch::zeros({nlocal, 3}); - - torch::Tensor inp_cell_volume = - torch::dot(inp_cell[0], torch::cross(inp_cell[1], inp_cell[2], 0)); - - float pbc_shift_tmp[nedges_upper_bound][3]; - - auto node_type = inp_node_type.accessor(); - auto pos = inp_pos.accessor(); - - long edge_idx_src[nedges_upper_bound]; - long edge_idx_dst[nedges_upper_bound]; - - int nedges = 0; - - for (int ii = 0; ii < nlocal; ii++) { - const int i = ilist[ii]; - int itag = tag_map[tag[i]]; - tag2i[itag - 1] = i; - const 
int itype = type[i]; - node_type[itag - 1] = map[itype]; - pos[itag - 1][0] = x[i][0]; - pos[itag - 1][1] = x[i][1]; - pos[itag - 1][2] = x[i][2]; - } - - for (int ii = 0; ii < nlocal; ii++) { - const int i = ilist[ii]; - int itag = tag_map[tag[i]]; - const int *jlist = firstneigh[i]; - const int jnum = numneigh[i]; - - for (int jj = 0; jj < jnum; jj++) { - int j = jlist[jj]; // atom over pbc is different atom - int jtag = tag_map[tag[j]]; // atom over pbs is same atom (it starts from 1) - j &= NEIGHMASK; - const int jtype = type[j]; - - const double delij[3] = {x[j][0] - x[i][0], x[j][1] - x[i][1], - x[j][2] - x[i][2]}; - const double Rij = - delij[0] * delij[0] + delij[1] * delij[1] + delij[2] * delij[2]; - if (Rij < cutoff_square) { - edge_idx_src[nedges] = itag - 1; - edge_idx_dst[nedges] = jtag - 1; - - pbc_shift_tmp[nedges][0] = x[j][0] - pos[jtag - 1][0]; - pbc_shift_tmp[nedges][1] = x[j][1] - pos[jtag - 1][1]; - pbc_shift_tmp[nedges][2] = x[j][2] - pos[jtag - 1][2]; - - nedges++; - } - } // j loop end - } // i loop end - - auto edge_idx_src_tensor = - torch::from_blob(edge_idx_src, {nedges}, INTEGER_TYPE); - auto edge_idx_dst_tensor = - torch::from_blob(edge_idx_dst, {nedges}, INTEGER_TYPE); - auto inp_edge_index = - torch::stack({edge_idx_src_tensor, edge_idx_dst_tensor}); - - // r' = r + {shift_tensor(integer vector of len 3)} @ cell_tensor - // shift_tensor = (cell_tensor)^-1^T @ (r' - r) - torch::Tensor cell_inv_tensor = - inp_cell.inverse().transpose(0, 1).unsqueeze(0).to(device); - torch::Tensor pbc_shift_tmp_tensor = - torch::from_blob(pbc_shift_tmp, {nedges, 3}, FLOAT_TYPE) - .view({nedges, 3, 1}) - .to(device); - torch::Tensor inp_cell_shift = - torch::bmm(cell_inv_tensor.expand({nedges, 3, 3}), pbc_shift_tmp_tensor) - .view({nedges, 3}); - - inp_pos.set_requires_grad(true); - - c10::Dict input_dict; - input_dict.insert("x", inp_node_type.to(device)); - input_dict.insert("pos", inp_pos.to(device)); - input_dict.insert("edge_index", 
inp_edge_index.to(device)); - input_dict.insert("num_atoms", inp_num_atoms.to(device)); - input_dict.insert("cell_lattice_vectors", inp_cell.to(device)); - input_dict.insert("cell_volume", inp_cell_volume.to(device)); - input_dict.insert("pbc_shift", inp_cell_shift); - - std::vector input(1, input_dict); - auto output = model.forward(input).toGenericDict(); - - torch::Tensor total_energy_tensor = - output.at("inferred_total_energy").toTensor().cpu(); - torch::Tensor force_tensor = output.at("inferred_force").toTensor().cpu(); - auto forces = force_tensor.accessor(); - eng_vdwl += total_energy_tensor.item(); - - for (int itag = 0; itag < nlocal; itag++) { - int i = tag2i[itag]; - f[i][0] += forces[itag][0]; - f[i][1] += forces[itag][1]; - f[i][2] += forces[itag][2]; - } - - if (vflag) { - // more accurately, it is virial part of stress - torch::Tensor stress_tensor = output.at("inferred_stress").toTensor().cpu(); - bool has_atomic_virial = false; - torch::Tensor stress_virial_tensor; - try { - stress_virial_tensor = output.at("inferred_atomic_virial").toTensor().cpu(); - has_atomic_virial = true; - } catch (...) 
{ - has_atomic_virial = false; - } - auto virial_stress_tensor = stress_tensor * inp_cell_volume; - if (vflag_atom && !has_atomic_virial) { - error->warning( - FLERR, - "Model output has no inferred_atomic_virial; " - "stress/atom consistency test cannot be completed"); - } - // xy yz zx order in vasp (voigt is xx yy zz yz xz xy) - auto virial_stress = virial_stress_tensor.accessor(); - virial[0] += virial_stress[0]; - virial[1] += virial_stress[1]; - virial[2] += virial_stress[2]; - virial[3] += virial_stress[3]; - virial[4] += virial_stress[5]; - virial[5] += virial_stress[4]; - - if (vflag_atom && has_atomic_virial) { - auto atomic_virial = stress_virial_tensor.accessor(); - for (int itag = 0; itag < nlocal; itag++) { - int i = tag2i[itag]; - vatom[i][0] += atomic_virial[itag][0]; - vatom[i][1] += atomic_virial[itag][1]; - vatom[i][2] += atomic_virial[itag][2]; - vatom[i][3] += atomic_virial[itag][3]; - vatom[i][4] += atomic_virial[itag][5]; - vatom[i][5] += atomic_virial[itag][4]; - } - } - } - - if (eflag_atom) { - torch::Tensor atomic_energy_tensor = - output.at("atomic_energy").toTensor().cpu().view({nlocal}); - auto atomic_energy = atomic_energy_tensor.accessor(); - for (int itag = 0; itag < nlocal; itag++) { - int i = tag2i[itag]; - eatom[i] += atomic_energy[itag]; - } - } - - // if it was the first MD step - if (this->nedges_bound == -1) { - this->nedges_bound = nedges * 1.2; - } // else if the nedges is too small, increase the bound - else if (nedges > this->nedges_bound / 1.2) { - this->nedges_bound = nedges * 1.2; - } -} - -// allocate arrays (called from coeff) -void PairE3GNN::allocate() { - allocated = 1; - int n = atom->ntypes; - - memory->create(setflag, n + 1, n + 1, "pair:setflag"); - memory->create(cutsq, n + 1, n + 1, "pair:cutsq"); - memory->create(map, n + 1, "pair:map"); -} - -// global settings for pair_style -void PairE3GNN::settings(int narg, char **arg) { - if (narg != 0) { - error->all(FLERR, "Illegal pair_style command"); - } -} - 
-void PairE3GNN::coeff(int narg, char **arg) { - - if (allocated) { - error->all(FLERR, "pair_e3gnn coeff called twice"); - } - allocate(); - - if (strcmp(arg[0], "*") != 0 || strcmp(arg[1], "*") != 0) { - error->all(FLERR, - "e3gnn: first and second input of pair_coeff should be '*'"); - } - // expected input : pair_coeff * * pot.pth type_name1 type_name2 ... - - std::unordered_map meta_dict = { - {"chemical_symbols_to_index", ""}, - {"cutoff", ""}, - {"num_species", ""}, - {"model_type", ""}, - {"version", ""}, - {"dtype", ""}, - {"flashTP", "version mismatch"}, - {"oeq", "version mismatch"}, - {"time", ""}}; - - // model loading from input - try { - model = torch::jit::load(std::string(arg[2]), device, meta_dict); - } catch (const c10::Error &e) { - error->all(FLERR, "error loading the model, check the path of the model"); - } - // model = torch::jit::freeze(model); model is already freezed - - torch::jit::setGraphExecutorOptimize(false); - torch::jit::FusionStrategy strategy; - // thing about dynamic recompile as tensor shape varies, this is default - // strategy = {{torch::jit::FusionBehavior::DYNAMIC, 3}}; - strategy = {{torch::jit::FusionBehavior::STATIC, 0}}; - torch::jit::setFusionStrategy(strategy); - - cutoff = std::stod(meta_dict["cutoff"]); - cutoff_square = cutoff * cutoff; - - // to make torch::autograd::grad() works - if (meta_dict["oeq"] == "yes") { - pair_e3gnn_oeq_register_autograd(); - } - - if (meta_dict["model_type"].compare("E3_equivariant_model") != 0) { - error->all(FLERR, "given model type is not E3_equivariant_model"); - } - - std::string chem_str = meta_dict["chemical_symbols_to_index"]; - int ntypes = atom->ntypes; - - auto delim = " "; - char *tok = std::strtok(const_cast(chem_str.c_str()), delim); - std::vector chem_vec; - while (tok != nullptr) { - chem_vec.push_back(std::string(tok)); - tok = std::strtok(nullptr, delim); - } - - bool found_flag = false; - for (int i = 3; i < narg; i++) { - found_flag = false; - for (int j = 0; j < 
chem_vec.size(); j++) { - if (chem_vec[j].compare(arg[i]) == 0) { - map[i - 2] = j; - found_flag = true; - fprintf(lmp->logfile, "Chemical specie '%s' is assigned to type %d\n", - arg[i], i - 2); - break; - } - } - if (!found_flag) { - error->all(FLERR, "Unknown chemical specie is given"); - } - } - - if (ntypes > narg - 3) { - error->all(FLERR, "Not enough chemical specie is given. Check pair_coeff " - "and types in your data/script"); - } - - for (int i = 1; i <= ntypes; i++) { - for (int j = 1; j <= ntypes; j++) { - if ((map[i] >= 0) && (map[j] >= 0)) { - setflag[i][j] = 1; - cutsq[i][j] = cutoff * cutoff; - } - } - } - - if (lmp->logfile) { - fprintf(lmp->logfile, "from sevenn version '%s' ", - meta_dict["version"].c_str()); - fprintf(lmp->logfile, "%s precision model, deployed: %s\n", - meta_dict["dtype"].c_str(), meta_dict["time"].c_str()); - fprintf(lmp->logfile, "FlashTP: %s\n", - meta_dict["flashTP"].c_str()); - fprintf(lmp->logfile, "OEQ: %s\n", - meta_dict["oeq"].c_str()); - } -} - -// init specific to this pair -void PairE3GNN::init_style() { - // Newton flag is irrelevant if use only one processor for simulation - /* - if (force->newton_pair == 0) { - error->all(FLERR, "Pair style nn requires newton pair on"); - } - */ - - // full neighbor list (this is many-body potential) - neighbor->add_request(this, NeighConst::REQ_FULL); -} - -double PairE3GNN::init_one(int i, int j) { return cutoff; } diff --git a/sevenn/pair_e3gnn/pair_e3gnn_parallel_atomic_stress.cpp b/sevenn/pair_e3gnn/pair_e3gnn_parallel_atomic_stress.cpp deleted file mode 100644 index 2ec93909..00000000 --- a/sevenn/pair_e3gnn/pair_e3gnn_parallel_atomic_stress.cpp +++ /dev/null @@ -1,913 +0,0 @@ -/* ---------------------------------------------------------------------- - LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator - https://lammps.sandia.gov/, Sandia National Laboratories - Steve Plimpton, sjplimp@sandia.gov - - Copyright (2003) Sandia Corporation. 
Under the terms of Contract - DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains - certain rights in this software. This software is distributed under - the GNU General Public License. - - See the README file in the top-level LAMMPS directory. -------------------------------------------------------------------------- */ - -/* ---------------------------------------------------------------------- - Contributing author: Yutack Park (SNU) -------------------------------------------------------------------------- */ - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -#include - -#include "atom.h" -#include "comm.h" -#include "comm_brick.h" -#include "error.h" -#include "force.h" -#include "memory.h" -#include "neigh_list.h" -#include "neighbor.h" -// #include "nvToolsExt.h" - -#include "pair_e3gnn_parallel.h" -#include - -#ifdef OMPI_MPI_H -#include "mpi-ext.h" //This should be included after mpi.h which is included in pair.h -#endif - -using namespace LAMMPS_NS; - -// Undefined reference; body in pair_e3gnn_oeq_autograd.cpp to be linked -extern void pair_e3gnn_oeq_register_autograd(); - -#define INTEGER_TYPE torch::TensorOptions().dtype(torch::kInt64) -#define FLOAT_TYPE torch::TensorOptions().dtype(torch::kFloat) - -DeviceBuffManager &DeviceBuffManager::getInstance() { - static DeviceBuffManager instance; - return instance; -} - -void DeviceBuffManager::get_buffer(int send_size, int recv_size, - float *&buf_send_ptr, float *&buf_recv_ptr) { - if (send_size > send_buf_size) { - cudaFree(buf_send_device); - cudaError_t cuda_err = - cudaMalloc(&buf_send_device, send_size * sizeof(float)); - send_buf_size = send_size; - } - if (recv_size > recv_buf_size) { - cudaFree(buf_recv_device); - cudaError_t cuda_err = - cudaMalloc(&buf_recv_device, recv_size * sizeof(float)); - recv_buf_size = recv_size; - } - buf_send_ptr = buf_send_device; - buf_recv_ptr = buf_recv_device; -} - 
-DeviceBuffManager::~DeviceBuffManager() { - cudaFree(buf_send_device); - cudaFree(buf_recv_device); -} - -PairE3GNNParallel::PairE3GNNParallel(LAMMPS *lmp) : Pair(lmp) { - // constructor - // Virial is accumulated directly from the model; skip LAMMPS fdotr post-pass. - no_virial_fdotr_compute = 1; - - const char *print_flag = std::getenv("SEVENN_PRINT_INFO"); - const char *print_both_flag = std::getenv("SEVENN_PRINT_BOTH_INFO"); - if (print_flag) { - world_rank = comm->me; - std::cout << "process rank: " << world_rank << " initialized" << std::endl; - print_info = (world_rank == 0) || print_both_flag; - } - - std::string device_name; - const bool use_gpu = torch::cuda::is_available(); - - comm_forward = 0; - comm_reverse = 0; - - // OpenMPI detection -#ifdef OMPI_MPI_H -#if defined(MPIX_CUDA_AWARE_SUPPORT) - if (1 == MPIX_Query_cuda_support()) { - use_cuda_mpi = true; - } else { - use_cuda_mpi = false; - } -#else - use_cuda_mpi = false; -#endif -#else - use_cuda_mpi = false; -#endif - // use_cuda_mpi = use_gpu && use_cuda_mpi; - // if (use_cuda_mpi) { - if (use_gpu) { - device = get_cuda_device(); - device_name = "CUDA"; - } else { - device = torch::kCPU; - device_name = "CPU"; - } - - if (std::getenv("OFF_E3GNN_PARALLEL_CUDA_MPI")) { - use_cuda_mpi = false; - } - - if (lmp->screen) { - if (use_gpu && !use_cuda_mpi) { - device_comm = torch::kCPU; - fprintf(lmp->screen, - "cuda-aware mpi not found, communicate via host device\n"); - } else { - device_comm = device; - } - fprintf(lmp->screen, "PairE3GNNParallel using device : %s\n", - device_name.c_str()); - fprintf(lmp->screen, "PairE3GNNParallel cuda-aware mpi: %s\n", - use_cuda_mpi ? 
"True" : "False"); - } - if (lmp->logfile) { - if (use_gpu && !use_cuda_mpi) { - device_comm = torch::kCPU; - fprintf(lmp->logfile, - "cuda-aware mpi not found, communicate via host device\n"); - } else { - device_comm = device; - } - fprintf(lmp->logfile, "PairE3GNNParallel using device : %s\n", - device_name.c_str()); - fprintf(lmp->logfile, "PairE3GNNParallel cuda-aware mpi: %s\n", - use_cuda_mpi ? "True" : "False"); - } -} - -torch::Device PairE3GNNParallel::get_cuda_device() { - char *cuda_visible = std::getenv("CUDA_VISIBLE_DEVICES"); - int num_gpus; - int idx; - int rank = comm->me; - num_gpus = torch::cuda::device_count(); - idx = rank % num_gpus; - if (print_info) - std::cout << world_rank << " Available # of GPUs found: " << num_gpus - << std::endl; - cudaError_t cuda_err = cudaSetDevice(idx); - if (cuda_err != cudaSuccess) { - std::cerr << "E3GNN: Failed to set CUDA device: " - << cudaGetErrorString(cuda_err) << std::endl; - } - return torch::Device(torch::kCUDA, idx); -} - -PairE3GNNParallel::~PairE3GNNParallel() { - if (allocated) { - memory->destroy(setflag); - memory->destroy(cutsq); - memory->destroy(map); - } -} - -int PairE3GNNParallel::get_x_dim() { return x_dim; } - -bool PairE3GNNParallel::use_cuda_mpi_() { return use_cuda_mpi; } - -bool PairE3GNNParallel::is_comm_preprocess_done() { - return comm_preprocess_done; -} - -void PairE3GNNParallel::compute(int eflag, int vflag) { - /* - Graph build on cpu - */ - if (eflag || vflag) - ev_setup(eflag, vflag); - else - evflag = vflag_fdotr = 0; -// if (vflag_atom) { -// error->all(FLERR, "atomic stress is not supported\n"); -// } - - if (atom->tag_consecutive() == 0) { - error->all(FLERR, "Pair e3gnn requires consecutive atom IDs"); - } - - double **x = atom->x; - double **f = atom->f; - int *type = atom->type; - int nlocal = list->inum; // same as nlocal - int nghost = atom->nghost; - int ntotal = nlocal + nghost; - int *ilist = list->ilist; - int inum = list->inum; - - CommBrick *comm_brick = 
dynamic_cast(comm); - if (comm_brick == nullptr) { - error->all(FLERR, "e3gnn/parallel: comm style should be brick & from " - "modified code of comm_brick"); - } - - bigint natoms = atom->natoms; - - // tag ignore PBC - tagint *tag = atom->tag; - - // store graph_idx from local to known ghost atoms(ghost atoms inside cutoff) - int tag_to_graph_idx[natoms + 1]; // tag starts from 1 not 0 - std::fill_n(tag_to_graph_idx, natoms + 1, -1); - - // to access tag_to_graph_idx from comm - tag_to_graph_idx_ptr = tag_to_graph_idx; - - int graph_indexer = nlocal; - int graph_index_to_i[ntotal]; - - int *numneigh = list->numneigh; // j loop cond - int **firstneigh = list->firstneigh; // j list - const int nedges_upper_bound = - std::accumulate(numneigh, numneigh + nlocal, 0); - - std::vector node_type; - std::vector node_type_ghost; - - float edge_vec[nedges_upper_bound][3]; - long edge_idx_src[nedges_upper_bound]; - long edge_idx_dst[nedges_upper_bound]; - - int nedges = 0; - for (int ii = 0; ii < inum; ii++) { - // populate tag_to_graph_idx of local atoms - const int i = ilist[ii]; - const int itag = tag[i]; - const int itype = type[i]; - tag_to_graph_idx[itag] = ii; - graph_index_to_i[ii] = i; - node_type.push_back(map[itype]); - } - - // loop over neighbors, build graph - for (int ii = 0; ii < inum; ii++) { - const int i = ilist[ii]; - const int i_graph_idx = ii; - const int *jlist = firstneigh[i]; - const int jnum = numneigh[i]; - - for (int jj = 0; jj < jnum; jj++) { - int j = jlist[jj]; - const int jtag = tag[j]; - j &= NEIGHMASK; - const int jtype = type[j]; - // we have to calculate Rij to check cutoff in lammps side - const double delij[3] = {x[j][0] - x[i][0], x[j][1] - x[i][1], - x[j][2] - x[i][2]}; - const double Rij = - delij[0] * delij[0] + delij[1] * delij[1] + delij[2] * delij[2]; - - int j_graph_idx; - if (Rij < cutoff_square) { - // if given j is not local atom and inside cutoff - if (tag_to_graph_idx[jtag] == -1) { - // if j is ghost atom inside cutoff but 
first seen - tag_to_graph_idx[jtag] = graph_indexer; - graph_index_to_i[graph_indexer] = j; - node_type_ghost.push_back(map[jtype]); - graph_indexer++; - } - - j_graph_idx = tag_to_graph_idx[jtag]; - edge_idx_src[nedges] = i_graph_idx; - edge_idx_dst[nedges] = j_graph_idx; - edge_vec[nedges][0] = delij[0]; - edge_vec[nedges][1] = delij[1]; - edge_vec[nedges][2] = delij[2]; - nedges++; - } - } // j loop end - } // i loop end - - // member variable - graph_size = graph_indexer; - const int ghost_node_num = graph_size - nlocal; - - // convert data to Tensor - auto inp_node_type = torch::from_blob(node_type.data(), nlocal, INTEGER_TYPE); - auto inp_node_type_ghost = - torch::from_blob(node_type_ghost.data(), ghost_node_num, INTEGER_TYPE); - - long num_nodes[1] = {long(nlocal)}; - auto inp_num_atoms = torch::from_blob(num_nodes, {1}, INTEGER_TYPE); - - auto edge_idx_src_tensor = - torch::from_blob(edge_idx_src, {nedges}, INTEGER_TYPE); - auto edge_idx_dst_tensor = - torch::from_blob(edge_idx_dst, {nedges}, INTEGER_TYPE); - auto inp_edge_index = - torch::stack({edge_idx_src_tensor, edge_idx_dst_tensor}); - - auto inp_edge_vec = torch::from_blob(edge_vec, {nedges, 3}, FLOAT_TYPE); - if (print_info) { - std::cout << world_rank << " Nlocal: " << nlocal << std::endl; - std::cout << world_rank << " Graph_size: " << graph_size << std::endl; - std::cout << world_rank << " Ghost_node_num: " << ghost_node_num - << std::endl; - std::cout << world_rank << " Nedges: " << nedges << "\n" << std::endl; - } - - // r_original requires grad True - inp_edge_vec.set_requires_grad(true); - - torch::Dict input_dict; - input_dict.insert("x", inp_node_type.to(device)); - input_dict.insert("x_ghost", inp_node_type_ghost.to(device)); - input_dict.insert("edge_index", inp_edge_index.to(device)); - input_dict.insert("edge_vec", inp_edge_vec.to(device)); - input_dict.insert("num_atoms", inp_num_atoms.to(device)); - input_dict.insert("nlocal", inp_num_atoms.to(torch::kCPU)); - - std::list> 
wrt_tensors; - wrt_tensors.push_back({input_dict.at("edge_vec")}); - - auto model_part = model_list.front(); - - auto output = model_part.forward({input_dict}).toGenericDict(); - - comm_preprocess(); - - // extra_graph_idx_map is set from comm_preprocess(); - // last one is for trash values. See pack_forward_init - const int extra_size = - ghost_node_num + static_cast(extra_graph_idx_map.size()) + 1; - torch::Tensor x_local; - torch::Tensor x_ghost; - - for (auto it = model_list.begin(); it != model_list.end(); ++it) { - if (it == model_list.begin()) - continue; - model_part = *it; - - x_local = output.at("x").toTensor().detach().to(device); - x_dim = x_local.size(1); // length of per atom vector(node feature) - - auto ghost_and_extra_x = torch::zeros({ghost_node_num + extra_size, x_dim}, - FLOAT_TYPE.device(device)); - x_comm = torch::cat({x_local, ghost_and_extra_x}, 0).to(device_comm); - comm_brick->forward_comm(this); // populate x_ghost by communication - - // What we got from forward_comm (node feature of ghosts) - x_ghost = torch::split_with_sizes( - x_comm, {nlocal, ghost_node_num, extra_size}, 0)[1]; - x_ghost.set_requires_grad(true); - - // prepare next input (output > next input) - output.insert_or_assign("x_ghost", x_ghost.to(device)); - // make another edge_vec to discriminate grad calculation with other - // edge_vecs(maybe redundant?) 
- output.insert_or_assign("edge_vec", - output.at("edge_vec").toTensor().clone()); - - // save tensors for backprop - wrt_tensors.push_back({output.at("edge_vec").toTensor(), - output.at("x").toTensor(), - output.at("self_cont_tmp").toTensor(), - output.at("x_ghost").toTensor()}); - - output = model_part.forward({output}).toGenericDict(); - } - torch::Tensor energy_tensor = - output.at("inferred_total_energy").toTensor().squeeze(); - - torch::Tensor dE_dr = - torch::zeros({nedges, 3}, FLOAT_TYPE.device(device)); // create on device - torch::Tensor x_local_save; // holds grad info of x_local (it loses its grad - // when sends to CPU) - torch::Tensor self_conn_grads; - std::vector grads; - std::vector of_tensor; - - // TODO: most values of self_conn_grads were zero because we use only scalars - // for energy - for (auto rit = wrt_tensors.rbegin(); rit != wrt_tensors.rend(); ++rit) { - // edge_vec, x, x_ghost order - auto wrt_tensor = *rit; - if (rit == wrt_tensors.rbegin()) { - grads = torch::autograd::grad({energy_tensor}, wrt_tensor); - } else { - x_local_save.copy_(x_local); - // of wrt grads_output - grads = torch::autograd::grad(of_tensor, wrt_tensor, - {x_local_save, self_conn_grads}); - } - - dE_dr = dE_dr + grads.at(0); // accumulate force - if (std::distance(rit, wrt_tensors.rend()) == 1) - continue; // if last iteration - - of_tensor.clear(); - of_tensor.push_back(wrt_tensor[1]); // x - of_tensor.push_back(wrt_tensor[2]); // self_cont_tmp - - x_local_save = grads.at(1); // for grads_output - x_local = x_local_save.detach(); // grad_outputs & communication - x_dim = x_local.size(1); - - self_conn_grads = grads.at(2); // no communication, for grads_output - - x_ghost = grads.at(3).detach(); // yes communication, not for grads_output - - auto extra_x = torch::zeros({extra_size, x_dim}, FLOAT_TYPE.device(device)); - x_comm = torch::cat({x_local, x_ghost, extra_x}, 0).to(device_comm); - - comm_brick->reverse_comm(this); // completes x_local - - // now x_local is 
complete (dE_dx), become next grads_output(with - // self_conn_grads) - x_local = torch::split_with_sizes( - x_comm, {nlocal, ghost_node_num, extra_size}, 0)[0]; - } - - // postprocessing - if (print_info) { - size_t free, tot; - cudaMemGetInfo(&free, &tot); - std::cout << world_rank << " MEM use after backward(MB)" << std::endl; - double Mfree = static_cast(free) / (1024 * 1024); - double Mtot = static_cast(tot) / (1024 * 1024); - std::cout << world_rank << " Total: " << Mtot << std::endl; - std::cout << world_rank << " Free: " << Mfree << std::endl; - std::cout << world_rank << " Used: " << Mtot - Mfree << std::endl; - double Mused = Mtot - Mfree; - std::cout << world_rank << " Used/Nedges: " << Mused / nedges << std::endl; - std::cout << world_rank << " Used/Nlocal: " << Mused / nlocal << std::endl; - std::cout << world_rank << " Used/GraphSize: " << Mused / graph_size << "\n" - << std::endl; - } - eng_vdwl += energy_tensor.item(); // accumulate energy - - dE_dr = dE_dr.to(torch::kCPU); - torch::Tensor force_tensor = torch::zeros({graph_indexer, 3}); - - auto _edge_idx_src_tensor = - edge_idx_src_tensor.repeat_interleave(3).view({nedges, 3}); - auto _edge_idx_dst_tensor = - edge_idx_dst_tensor.repeat_interleave(3).view({nedges, 3}); - - force_tensor.scatter_reduce_(0, _edge_idx_src_tensor, dE_dr, "sum"); - force_tensor.scatter_reduce_(0, _edge_idx_dst_tensor, torch::neg(dE_dr), - "sum"); - - auto forces = force_tensor.accessor(); - - for (int graph_idx = 0; graph_idx < graph_indexer; graph_idx++) { - int i = graph_index_to_i[graph_idx]; - f[i][0] += forces[graph_idx][0]; - f[i][1] += forces[graph_idx][1]; - f[i][2] += forces[graph_idx][2]; - } - - if (vflag) { - auto diag = inp_edge_vec * dE_dr; - auto s12 = inp_edge_vec.select(1, 0) * dE_dr.select(1, 1); - auto s23 = inp_edge_vec.select(1, 1) * dE_dr.select(1, 2); - auto s31 = inp_edge_vec.select(1, 2) * dE_dr.select(1, 0); - std::vector voigt_list = { - diag, s12.unsqueeze(-1), s23.unsqueeze(-1), 
s31.unsqueeze(-1)}; - auto voigt = torch::cat(voigt_list, 1); - - torch::Tensor per_atom_stress_tensor = torch::zeros({graph_indexer, 6}); - auto _edge_idx_dst6_tensor = - edge_idx_dst_tensor.repeat_interleave(6).view({nedges, 6}); - per_atom_stress_tensor.scatter_reduce_(0, _edge_idx_dst6_tensor, voigt, - "sum"); - auto virial_stress_tensor = - torch::neg(torch::sum(per_atom_stress_tensor, 0)); - auto virial_stress = virial_stress_tensor.accessor(); - - virial[0] += virial_stress[0]; - virial[1] += virial_stress[1]; - virial[2] += virial_stress[2]; - virial[3] += virial_stress[3]; - virial[4] += virial_stress[5]; - virial[5] += virial_stress[4]; - } - - if (eflag_atom) { - torch::Tensor atomic_energy_tensor = - output.at("atomic_energy").toTensor().cpu().view({nlocal}); - auto atomic_energy = atomic_energy_tensor.accessor(); - for (int graph_idx = 0; graph_idx < nlocal; graph_idx++) { - int i = graph_index_to_i[graph_idx]; - eatom[i] += atomic_energy[graph_idx]; - } - } - - // clean up comm preprocess variables - comm_preprocess_done = false; - for (int i = 0; i < 6; i++) { - // array of vector - comm_index_pack_forward[i].clear(); - comm_index_unpack_forward[i].clear(); - comm_index_unpack_reverse[i].clear(); - } - - extra_graph_idx_map.clear(); -} - -// allocate arrays (called from coeff) -void PairE3GNNParallel::allocate() { - allocated = 1; - int n = atom->ntypes; - - memory->create(setflag, n + 1, n + 1, "pair:setflag"); - memory->create(cutsq, n + 1, n + 1, "pair:cutsq"); - memory->create(map, n + 1, "pair:map"); -} - -// global settings for pair_style -void PairE3GNNParallel::settings(int narg, char **arg) { - if (narg != 0) { - error->all(FLERR, "Illegal pair_style command"); - } -} - -void PairE3GNNParallel::coeff(int narg, char **arg) { - if (allocated) { - error->all(FLERR, "pair_e3gnn coeff called twice"); - } - allocate(); - - if (strcmp(arg[0], "*") != 0 || strcmp(arg[1], "*") != 0) { - error->all(FLERR, - "e3gnn: first and second input of pair_coeff 
should be '*'"); - } - // expected input : pair_coeff * * pot.pth type_name1 type_name2 ... - - std::unordered_map meta_dict = { - {"chemical_symbols_to_index", ""}, - {"cutoff", ""}, - {"num_species", ""}, - {"model_type", ""}, - {"version", ""}, - {"dtype", ""}, - {"time", ""}, - {"flashTP", "version mismatch"}, - {"oeq", "version mismatch"}, - {"comm_size", ""}}; - - // model loading from input - int n_model = std::stoi(arg[2]); - int chem_arg_i = 4; - std::vector model_fnames; - if (std::filesystem::exists(arg[3])) { - if (std::filesystem::is_directory(arg[3])) { - auto headf = std::string(arg[3]); - for (int i = 0; i < n_model; i++) { - auto stri = std::to_string(i); - model_fnames.push_back(headf + "/deployed_parallel_" + stri + ".pt"); - } - } else if (std::filesystem::is_regular_file(arg[3])) { - for (int i = 3; i < n_model + 3; i++) { - model_fnames.push_back(std::string(arg[i])); - } - chem_arg_i = n_model + 3; - } else { - error->all(FLERR, "No such file or directory:" + std::string(arg[3])); - } - } - - for (const auto &modelf : model_fnames) { - if (!std::filesystem::is_regular_file(modelf)) { - error->all(FLERR, "Expected this is a regular file:" + modelf); - } - model_list.push_back(torch::jit::load(modelf, device, meta_dict)); - } - - torch::jit::setGraphExecutorOptimize(false); - torch::jit::FusionStrategy strategy; - // strategy = {{torch::jit::FusionBehavior::DYNAMIC, 3}}; - strategy = {{torch::jit::FusionBehavior::STATIC, 0}}; - torch::jit::setFusionStrategy(strategy); - - cutoff = std::stod(meta_dict["cutoff"]); - - // maximum possible size of per atom x before last convolution - int comm_size = std::stod(meta_dict["comm_size"]); - - // to initialize buffer size for communication - comm_forward = comm_size; - comm_reverse = comm_size; - - cutoff_square = cutoff * cutoff; - - // to make torch::autograd::grad() works - if (meta_dict["oeq"] == "yes") { - pair_e3gnn_oeq_register_autograd(); - } - - if 
(meta_dict["model_type"].compare("E3_equivariant_model") != 0) { - error->all(FLERR, "given model type is not E3_equivariant_model"); - } - - std::string chem_str = meta_dict["chemical_symbols_to_index"]; - int ntypes = atom->ntypes; - - auto delim = " "; - char *tok = std::strtok(const_cast(chem_str.c_str()), delim); - std::vector chem_vec; - while (tok != nullptr) { - chem_vec.push_back(std::string(tok)); - tok = std::strtok(nullptr, delim); - } - - // what if unknown chemical specie is in arg? should I abort? is there any use - // case for that? - bool found_flag = false; - int n_chem = narg - chem_arg_i; - for (int i = 0; i < n_chem; i++) { - found_flag = false; - for (int j = 0; j < chem_vec.size(); j++) { - if (chem_vec[j].compare(arg[i + chem_arg_i]) == 0) { - map[i + 1] = j; // store from 1, (not 0) - found_flag = true; - if (lmp->logfile) { - fprintf(lmp->logfile, "Chemical specie '%s' is assigned to type %d\n", - arg[i + chem_arg_i], i + 1); - break; - } - } - } - if (!found_flag) { - error->all(FLERR, "Unknown chemical specie is given or the number of " - "potential files is not consistent"); - } - } - - for (int i = 1; i <= ntypes; i++) { - for (int j = 1; j <= ntypes; j++) { - if ((map[i] >= 0) && (map[j] >= 0)) { - setflag[i][j] = 1; - cutsq[i][j] = cutoff * cutoff; - } - } - } - - if (lmp->logfile) { - fprintf(lmp->logfile, "from sevenn version '%s' ", - meta_dict["version"].c_str()); - fprintf(lmp->logfile, "%s precision model, deployed: %s\n", - meta_dict["dtype"].c_str(), meta_dict["time"].c_str()); - fprintf(lmp->logfile, "FlashTP: %s\n", - meta_dict["flashTP"].c_str()); - fprintf(lmp->logfile, "OEQ: %s\n", - meta_dict["oeq"].c_str()); - } -} - -// init specific to this pair -void PairE3GNNParallel::init_style() { - // full neighbor list & newton on - if (force->newton_pair == 0) { - error->all(FLERR, "Pair style e3gnn/parallel requires newton pair on"); - } - neighbor->add_request(this, NeighConst::REQ_FULL); -} - -double 
PairE3GNNParallel::init_one(int i, int j) { return cutoff; } - -void PairE3GNNParallel::notify_proc_ids(const int *sendproc, const int *recvproc) { - for (int iswap = 0; iswap < 6; iswap++) { - this->sendproc[iswap] = sendproc[iswap]; - this->recvproc[iswap]= recvproc[iswap]; - } -} - -void PairE3GNNParallel::comm_preprocess() { - assert(!comm_preprocess_done); - CommBrick *comm_brick = dynamic_cast(comm); - - // fake lammps communication call to preprocess index - // gives complete comm_index_pack, unpack_forward, and extra_graph_idx_map - comm_brick->forward_comm(this); - - std::map> already_met_map; - for (int comm_phase = 0; comm_phase < 6; comm_phase++) { - const int n = comm_index_pack_forward[comm_phase].size(); - int sproc = this->sendproc[comm_phase]; - if (already_met_map.count(sproc) == 0) { - already_met_map.insert({sproc, std::set()}); - } - - // for unpack_reverse, Ignore duplicated index by 'already_met' - std::vector &idx_map_forward = comm_index_pack_forward[comm_phase]; - std::vector &idx_map_reverse = comm_index_unpack_reverse[comm_phase]; - std::set& already_met = already_met_map[sproc]; - // the last index of x_comm is used to trash unnecessary values - const int trash_index = - graph_size + static_cast(extra_graph_idx_map.size()); //+ 1; - for (int i = 0; i < n; i++) { - const int idx = idx_map_forward[i]; - if (idx < graph_size) { - if (already_met.count(idx) == 1) { - idx_map_reverse.push_back(trash_index); - } else { - idx_map_reverse.push_back(idx); - already_met.insert(idx); - } - } else { - idx_map_reverse.push_back(idx); - } - } - - if (use_cuda_mpi) { - comm_index_pack_forward_tensor[comm_phase] = torch::from_blob(idx_map_forward.data(), idx_map_forward.size(), INTEGER_TYPE).to(device); - - auto upmap = comm_index_unpack_forward[comm_phase]; - comm_index_unpack_forward_tensor[comm_phase] = torch::from_blob(upmap.data(), upmap.size(), INTEGER_TYPE).to(device); - comm_index_unpack_reverse_tensor[comm_phase] = 
torch::from_blob(idx_map_reverse.data(), idx_map_reverse.size(), INTEGER_TYPE).to(device); - } - } - comm_preprocess_done = true; -} - -// called from comm_brick if comm_preprocess_done is false -void PairE3GNNParallel::pack_forward_init(int n, int *list_send, - int comm_phase) { - std::vector &idx_map = comm_index_pack_forward[comm_phase]; - - idx_map.reserve(n); - - int i, j; - int nlocal = list->inum; - tagint *tag = atom->tag; - - for (i = 0; i < n; i++) { - int list_i = list_send[i]; - int graph_idx = tag_to_graph_idx_ptr[tag[list_i]]; - - if (graph_idx != -1) { - // known atom (local atom + ghost atom inside cutoff) - idx_map.push_back(graph_idx); - } else { - // unknown atom, these are not used in computation in this process - // instead, this process is used to hand over these atoms to other proecss - // hold them in continuous manner for flexible tensor operations later - if (extra_graph_idx_map.find(list_i) != extra_graph_idx_map.end()) { - idx_map.push_back(extra_graph_idx_map[list_i]); - } else { - // unknown atom at pack forward, ghost atom outside cutoff? 
- extra_graph_idx_map[i] = graph_size + extra_graph_idx_map.size(); - idx_map.push_back(extra_graph_idx_map[i]); // same as list_i in pack - } - } - } -} - -// called from comm_brick if comm_preprocess_done is false -void PairE3GNNParallel::unpack_forward_init(int n, int first, int comm_phase) { - std::vector &idx_map = comm_index_unpack_forward[comm_phase]; - - idx_map.reserve(n); - - int i, j, last; - last = first + n; - int nlocal = list->inum; - tagint *tag = atom->tag; - - for (i = first; i < last; i++) { - int graph_idx = tag_to_graph_idx_ptr[tag[i]]; - if (graph_idx != -1) { - idx_map.push_back(graph_idx); - } else { - extra_graph_idx_map[i] = graph_size + extra_graph_idx_map.size(); - idx_map.push_back(extra_graph_idx_map[i]); // same as list_i in pack - } - } -} - -int PairE3GNNParallel::pack_forward_comm_gnn(float *buf, int comm_phase) { - std::vector &idx_map = comm_index_pack_forward[comm_phase]; - const int n = static_cast(idx_map.size()); - if (use_cuda_mpi && n != 0) { - torch::Tensor &idx_map_tensor = comm_index_pack_forward_tensor[comm_phase]; - auto selected = x_comm.index_select(0, idx_map_tensor); // its size is x_dim * n - cudaError_t cuda_err = - cudaMemcpy(buf, selected.data_ptr(), (x_dim * n) * sizeof(float), - cudaMemcpyDeviceToDevice); - } else { - int i, j, m; - m = 0; - for (i = 0; i < n; i++) { - const int idx = static_cast(idx_map.at(i)); - float *from = x_comm[idx].data_ptr(); - for (j = 0; j < x_dim; j++) { - buf[m++] = from[j]; - } - } - } - if (print_info) { - std::cout << world_rank << " comm_phase: " << comm_phase << std::endl; - std::cout << world_rank << " pack_forward x_dim: " << x_dim << std::endl; - std::cout << world_rank << " pack_forward n: " << n << std::endl; - std::cout << world_rank << " pack_forward x_dim*n: " << x_dim * n - << std::endl; - double Msend = static_cast(x_dim * n * 4) / (1024 * 1024); - std::cout << world_rank << " send size(MB): " << Msend << "\n" << std::endl; - } - return x_dim * n; -} - -void 
PairE3GNNParallel::unpack_forward_comm_gnn(float *buf, int comm_phase) { - std::vector &idx_map = comm_index_unpack_forward[comm_phase]; - const int n = static_cast(idx_map.size()); - - if (use_cuda_mpi && n != 0) { - torch::Tensor &idx_map_tensor = comm_index_unpack_forward_tensor[comm_phase]; - auto buf_tensor = - torch::from_blob(buf, {n, x_dim}, FLOAT_TYPE.device(device)); - x_comm.scatter_(0, idx_map_tensor.repeat_interleave(x_dim).view({n, x_dim}), - buf_tensor); - } else { - int i, j, m; - m = 0; - for (i = 0; i < n; i++) { - const int idx = static_cast(idx_map.at(i)); - float *to = x_comm[idx].data_ptr(); - for (j = 0; j < x_dim; j++) { - to[j] = buf[m++]; - } - } - } -} - -int PairE3GNNParallel::pack_reverse_comm_gnn(float *buf, int comm_phase) { - std::vector &idx_map = comm_index_unpack_forward[comm_phase]; - const int n = static_cast(idx_map.size()); - - if (use_cuda_mpi && n != 0) { - torch::Tensor &idx_map_tensor = comm_index_unpack_forward_tensor[comm_phase]; - auto selected = x_comm.index_select(0, idx_map_tensor); - cudaError_t cuda_err = cudaMemcpy(buf, selected.data_ptr(), (x_dim * n) * sizeof(float), cudaMemcpyDeviceToDevice); - } else { - int i, j, m; - m = 0; - for (i = 0; i < n; i++) { - const int idx = static_cast(idx_map.at(i)); - float *from = x_comm[idx].data_ptr(); - for (j = 0; j < x_dim; j++) { - buf[m++] = from[j]; - } - } - } - if (print_info) { - std::cout << world_rank << " comm_phase: " << comm_phase << std::endl; - std::cout << world_rank << " pack_reverse x_dim: " << x_dim << std::endl; - std::cout << world_rank << " pack_reverse n: " << n << std::endl; - std::cout << world_rank << " pack_reverse x_dim*n: " << x_dim * n - << std::endl; - double Msend = static_cast(x_dim * n * 4) / (1024 * 1024); - } - return x_dim * n; -} - -void PairE3GNNParallel::unpack_reverse_comm_gnn(float *buf, int comm_phase) { - std::vector &idx_map = comm_index_unpack_reverse[comm_phase]; - const int n = static_cast(idx_map.size()); - - if (use_cuda_mpi 
&& n != 0) { - torch::Tensor &idx_map_tensor = comm_index_unpack_reverse_tensor[comm_phase]; - auto buf_tensor = - torch::from_blob(buf, {n, x_dim}, FLOAT_TYPE.device(device)); - x_comm.scatter_(0, idx_map_tensor.repeat_interleave(x_dim).view({n, x_dim}), - buf_tensor, "add"); - } else { - int i, j, m; - m = 0; - for (i = 0; i < n; i++) { - const int idx = static_cast(idx_map.at(i)); - if (idx == -1) { - m += x_dim; - continue; - } - float *to = x_comm[idx].data_ptr(); - for (j = 0; j < x_dim; j++) { - to[j] += buf[m++]; - } - } - } -} diff --git a/sevenn/pair_e3gnn/patch_lammps.sh b/sevenn/pair_e3gnn/patch_lammps.sh index e1b7eec9..ec46c1ba 100755 --- a/sevenn/pair_e3gnn/patch_lammps.sh +++ b/sevenn/pair_e3gnn/patch_lammps.sh @@ -5,7 +5,6 @@ cxx_standard=$2 # 14, 17 d3_support=$3 # 1, 0 flashTP_so="${4:-NONE}" oeq_so="${5:-NONE}" -atomic_stress="${6:-0}" # 1, 0 SCRIPT_DIR=$(dirname "${BASH_SOURCE[0]}") enable_flashTP=0 enable_oeq=0 @@ -15,8 +14,8 @@ enable_oeq=0 ########################################### # Check the number of arguments -if [ "$#" -lt 3 ] || [ "$#" -gt 6 ]; then - echo "Usage: sh patch_lammps.sh {lammps_root} {cxx_standard} {d3_support} {flashTP_so} {oeq_so} {atomic_stress}" +if [ "$#" -lt 3 ] || [ "$#" -gt 5 ]; then + echo "Usage: sh patch_lammps.sh {lammps_root} {cxx_standard} {d3_support} {flashTP_so} {oeq_so}" echo " {lammps_root}: Root directory of LAMMPS source" echo " {cxx_standard}: C++ standard (14, 17)" echo " {d3_support}: Support for pair_d3 (1, 0)" @@ -120,13 +119,7 @@ cp $lammps_root/cmake/CMakeLists.txt $backup_dir/CMakeLists.txt ########################################### # 1. 
Copy pair_e3gnn files to LAMMPS source -if [ "$atomic_stress" -eq 1 ]; then -cp $SCRIPT_DIR/pair_e3gnn_atomic_stress.cpp $lammps_root/src/pair_e3gnn.cpp -cp $SCRIPT_DIR/pair_e3gnn_parallel_atomic_stress.cpp $lammps_root/src/pair_e3gnn_parallel.cpp -cp $SCRIPT_DIR/comm_brick.cpp $lammps_root/src/ -else cp $SCRIPT_DIR/{pair_e3gnn,pair_e3gnn_parallel,comm_brick}.cpp $lammps_root/src/ -fi cp $SCRIPT_DIR/{pair_e3gnn,pair_e3gnn_parallel,comm_brick}.h $lammps_root/src/ # Always copy the oEq autograd bridge (pair_e3gnn.cpp has an extern reference to it) cp $SCRIPT_DIR/pair_e3gnn_oeq_autograd.cpp $lammps_root/src/ # TODO: set this as oeq-specific @@ -206,9 +199,6 @@ fi echo "Changes made:" echo " - Original LAMMPS files (src/comm_brick.*, cmake/CMakeList.txt) are in {lammps_root}/_backups" echo " - Copied contents of pair_e3gnn to $lammps_root/src/" -if [ "$atomic_stress" -eq 1 ]; then - echo " - Atomic stress patch mode enabled: using pair_e3gnn*_atomic_stress.cpp" -fi echo " - Patched CMakeLists.txt: include LibTorch, CXX_STANDARD $cxx_standard" echo if [ "$enable_flashTP" -eq 1 ]; then diff --git a/sevenn/scripts/deploy.py b/sevenn/scripts/deploy.py index 573ccc78..aa7d62c8 100644 --- a/sevenn/scripts/deploy.py +++ b/sevenn/scripts/deploy.py @@ -19,7 +19,6 @@ def deploy( modal: Optional[str] = None, use_flash: bool = False, use_oeq: bool = False, - atomic_virial: bool = False, ) -> None: cp = load_checkpoint(checkpoint) model, config = ( @@ -68,7 +67,6 @@ def deploy( md_configs.update({'version': __version__}) md_configs.update({'dtype': config.pop(KEY.DTYPE, 'single')}) md_configs.update({'time': datetime.now().strftime('%Y-%m-%d')}) - md_configs.update({'atomic_virial': 'yes' if atomic_virial else 'no'}) if fname.endswith('.pt') is False: fname += '.pt' @@ -170,79 +168,3 @@ def deploy_parallel( model = torch.jit.freeze(model) torch.jit.save(model, fname_full, _extra_files=md_configs) - - -def deploy_ts( - checkpoint: Union[pathlib.Path, str], - 
fname='deployed_model.pt', - modal: Optional[str] = None, - use_flash: bool = False, - use_oeq: bool = False, - atomic_virial: bool = False, -) -> None: - ''' - only for SevenNetCalculator with torchscript input (not for e3gnn) - ''' - from sevenn.nn.edge_embedding import EdgePreprocess - from sevenn.nn.force_output import ( - ForceStressOutput, - ForceStressOutputFromEdge, - ) - - cp = load_checkpoint(checkpoint) - - model, config = ( - cp.build_model( - enable_cueq=False, - enable_flash=use_flash, - enable_oeq=use_oeq, - _flash_lammps=use_flash, - ), - cp.config, - ) - - model.prepand_module('edge_preprocess', EdgePreprocess(True)) - grad_module = ForceStressOutputFromEdge(use_atomic_virial=True) - model.replace_module('force_output', grad_module) - new_grad_key = grad_module.get_grad_key() - model.key_grad = new_grad_key - - if hasattr(model, 'eval_type_map'): - setattr(model, 'eval_type_map', False) - - if modal: - model.prepare_modal_deploy(modal) - elif model.modal_map is not None and len(model.modal_map) >= 1: - raise ValueError( - f'Modal is not given. 
It has: {list(model.modal_map.keys())}' - ) - - model.set_is_batch_data(False) - model.eval() - - model = e3nn.util.jit.script(model) - model = torch.jit.freeze(model) - - # make some config need for md - md_configs = {} - type_map = config[KEY.TYPE_MAP] - chem_list = '' - for Z in type_map.keys(): - chem_list += chemical_symbols[Z] + ' ' - chem_list.strip() - md_configs.update({'chemical_symbols_to_index': chem_list}) - md_configs.update({'cutoff': str(config[KEY.CUTOFF])}) - md_configs.update({'num_species': str(config[KEY.NUM_SPECIES])}) - md_configs.update({'flashTP': 'yes' if use_flash else 'no'}) - md_configs.update({'oeq': 'yes' if use_oeq else 'no'}) - md_configs.update( - {'model_type': config.pop(KEY.MODEL_TYPE, 'E3_equivariant_model')} - ) - md_configs.update({'version': __version__}) - md_configs.update({'dtype': config.pop(KEY.DTYPE, 'single')}) - md_configs.update({'time': datetime.now().strftime('%Y-%m-%d')}) - md_configs.update({'atomic_virial': 'yes' if atomic_virial else 'no'}) - - if fname.endswith('.pt') is False: - fname += '.pt' - torch.jit.save(model, fname, _extra_files=md_configs) diff --git a/tests/lammps_tests/test_lammps.py b/tests/lammps_tests/test_lammps.py index 502fd3e1..a74f4c8e 100644 --- a/tests/lammps_tests/test_lammps.py +++ b/tests/lammps_tests/test_lammps.py @@ -18,7 +18,7 @@ from sevenn.nn.cue_helper import is_cue_available from sevenn.nn.flash_helper import is_flash_available from sevenn.nn.oeq_helper import is_oeq_available -from sevenn.scripts.deploy import deploy, deploy_parallel, deploy_ts +from sevenn.scripts.deploy import deploy, deploy_parallel from sevenn.util import chemical_species_preprocess, pretrained_name_to_path logger = logging.getLogger('test_lammps') @@ -110,19 +110,18 @@ def ref_modal_calculator(): @pytest.fixture(scope='module') -def ref_stress_calculator(tmp_path_factory): - tmp = tmp_path_factory.mktemp('serial_potential_atomic_virial') - pot_path = str(tmp / 'deployed_atomic_virial.pt') - 
deploy_ts(cp_0_path, pot_path, atomic_virial=True) - return SevenNetCalculator(pot_path, file_type='torchscript') +def ref_stress_calculator(): + return SevenNetCalculator(cp_0_path, atomic_virial=True) @pytest.fixture(scope='module') -def ref_7net0_stress_calculator(tmp_path_factory): - tmp = tmp_path_factory.mktemp('serial_potential_atomic_virial') - pot_path = str(tmp / 'deployed_atomic_virial.pt') - deploy_ts(cp_7net0_path, pot_path, atomic_virial=True) - return SevenNetCalculator(pot_path, file_type='torchscript') +def ref_7net0_stress_calculator(): + return SevenNetCalculator(cp_7net0_path, atomic_virial=True) + + +@pytest.fixture(scope='module') +def ref_modal_stress_calculator(): + return SevenNetCalculator(cp_mf_path, modal='PBE', atomic_virial=True) def get_model_config(): @@ -502,6 +501,25 @@ def test_serial_stress_flash( assert_atoms(atoms, atoms_lammps, atol=1e-5, check_atomic_stress=True) +@pytest.mark.parametrize( + 'system', + ['bulk', 'surface'], +) +def test_modal_serial_stress( + system, serial_modal_potential_path, ref_modal_stress_calculator, lammps_cmd, tmp_path +): + atoms = get_system(system) + atoms_lammps = serial_stress_lammps_run( + atoms=atoms, + potential=serial_modal_potential_path, + wd=tmp_path, + test_name='modal serial lmp test', + lammps_cmd=lammps_cmd, + ) + atoms.calc = ref_modal_stress_calculator + assert_atoms(atoms, atoms_lammps) + + @pytest.mark.parametrize( 'system,ncores', [ diff --git a/tests/unit_tests/test_atomic_virial.py b/tests/unit_tests/test_atomic_virial.py deleted file mode 100644 index 3bc56237..00000000 --- a/tests/unit_tests/test_atomic_virial.py +++ /dev/null @@ -1,61 +0,0 @@ -import numpy as np -from ase.build import bulk - -import sevenn._keys as KEY -from sevenn.calculator import SevenNetCalculator -from sevenn.scripts.deploy import deploy_ts -from sevenn.util import pretrained_name_to_path - - -def _get_atoms_pbc(): - atoms = bulk('NaCl', 'rocksalt', a=5.63) - atoms.set_cell([[1.0, 2.815, 2.815], [2.815, 
0.0, 2.815], [2.815, 2.815, 0.0]]) - atoms.set_positions([[0.0, 0.0, 0.0], [2.815, 0.0, 0.0]]) - return atoms - - -def test_atomic_virial_is_exposed_in_python_torchscript_path(tmp_path): - model_path = str(tmp_path / '7net_0_atomic_virial.pt') - deploy_ts(pretrained_name_to_path('7net-0_11July2024'), model_path, atomic_virial=True) - - calc = SevenNetCalculator(model_path, file_type='torchscript') - atoms = _get_atoms_pbc() - atoms.calc = calc - _ = atoms.get_potential_energy() - - assert 'stresses' in calc.results - atomic_virial = np.asarray(calc.results['stresses']) - - assert atomic_virial.shape == (len(atoms), 6) - assert np.isfinite(atomic_virial).all() - assert np.any(np.abs(atomic_virial) > 0.0) - # Full 6-component per-atom reference check. - # Sort rows for deterministic comparison when atom-wise ordering changes. - atomic_virial_ref = np.array( - [ - [10.06461800, -0.07430478, -0.07463801, -0.38235345, -0.04390856, -0.47967120], - [13.55999100, 1.41123860, 1.41158500, -1.28466740, -0.17591128, -1.18766780], - ] - ) - atomic_virial_sorted = atomic_virial[np.argsort(atomic_virial[:, 0])] - atomic_virial_ref_sorted = atomic_virial_ref[np.argsort(atomic_virial_ref[:, 0])] - assert np.allclose(atomic_virial_sorted, atomic_virial_ref_sorted, atol=1e-4) - - # Internal model stress ordering is [xx, yy, zz, xy, yz, zx]. - # Calculator exposes ASE stress as -stress_internal with - # component order [xx, yy, zz, yz, zx, xy]. 
- virial_sum = atomic_virial.sum(axis=0) - virial_sum_ref = np.array([ - 23.62459886, - 1.33693361, - 1.33694608, - -1.66702306, - -0.21981908, - -1.66734152, - ]) - assert np.allclose(virial_sum, virial_sum_ref, atol=1e-4) - - stress_internal_from_virial = virial_sum / atoms.get_volume() - stress_ase_from_virial = -stress_internal_from_virial[[0, 1, 2, 4, 5, 3]] - - assert np.allclose(calc.results['stress'], stress_ase_from_virial, atol=1e-5) diff --git a/tests/unit_tests/test_calculator.py b/tests/unit_tests/test_calculator.py index e4aacf2e..6b11b717 100644 --- a/tests/unit_tests/test_calculator.py +++ b/tests/unit_tests/test_calculator.py @@ -8,7 +8,6 @@ from sevenn.calculator import D3Calculator, SevenNetCalculator from sevenn.nn.cue_helper import is_cue_available from sevenn.nn.flash_helper import is_flash_available -from sevenn.scripts.deploy import deploy_ts from sevenn.util import model_from_checkpoint, pretrained_name_to_path @@ -105,26 +104,6 @@ def test_sevennet_0_cal_mol(atoms_mol, sevennet_0_cal): assert np.allclose(atoms_mol.get_potential_energies(), atoms2_ref['energies']) -def test_sevennet_0_cal_deployed_consistency(tmp_path, atoms_pbc): - atoms_pbc.rattle(stdev=0.01, seed=42) - fname = str(tmp_path / '7net_0.pt') - deploy_ts(pretrained_name_to_path('7net-0_11July2024'), fname) - - calc_script = SevenNetCalculator(fname, file_type='torchscript') - calc_cp = SevenNetCalculator(pretrained_name_to_path('7net-0_11July2024')) - - atoms_pbc.calc = calc_cp - atoms_pbc.get_potential_energy() - res_cp = copy.copy(atoms_pbc.calc.results) - - atoms_pbc.calc = calc_script - atoms_pbc.get_potential_energy() - res_script = copy.copy(atoms_pbc.calc.results) - - for k in res_cp: - assert np.allclose(res_cp[k], res_script[k], rtol=1e-4, atol=1e-4) - - def test_sevennet_0_cal_as_instance_consistency(atoms_pbc): atoms_pbc.rattle(stdev=0.01, seed=42) model, _ = model_from_checkpoint(pretrained_name_to_path('7net-0_11July2024')) From 
a9700626675ea4cbc5670a7b3e9a8bf585bfd220 Mon Sep 17 00:00:00 2001 From: lee2tae Date: Fri, 27 Mar 2026 17:15:38 +0900 Subject: [PATCH 5/7] bugfix --- sevenn/main/sevenn_get_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sevenn/main/sevenn_get_model.py b/sevenn/main/sevenn_get_model.py index 5dd373ce..d98eb21d 100644 --- a/sevenn/main/sevenn_get_model.py +++ b/sevenn/main/sevenn_get_model.py @@ -78,7 +78,6 @@ def run(args): use_cueq = args.enable_cueq use_oeq = args.enable_oeq use_mliap = args.use_mliap - atomic_virial = args.atomic_virial # Check dependencies if use_flash: From aae23970a8fe1f3674542f7f2ea0113a2f8aed29 Mon Sep 17 00:00:00 2001 From: lee2tae Date: Fri, 27 Mar 2026 17:35:59 +0900 Subject: [PATCH 6/7] calculator does not support torchscript input anymore --- sevenn/calculator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sevenn/calculator.py b/sevenn/calculator.py index 06c52163..9bb70665 100644 --- a/sevenn/calculator.py +++ b/sevenn/calculator.py @@ -52,7 +52,7 @@ def __init__( Name of pretrained models (7net-omni, 7net-mf-ompa, 7net-omat, 7net-l3i5, 7net-0) or path to the checkpoint, deployed model or the model itself file_type: str, default='checkpoint' - one of 'checkpoint' | 'torchscript' | 'model_instance' + one of 'checkpoint' | 'model_instance' device: str | torch.device, default='auto' if not given, use CUDA if available modal: str | None, default=None @@ -81,7 +81,7 @@ def __init__( if isinstance(model, pathlib.PurePath): model = str(model) - allowed_file_types = ['checkpoint', 'torchscript', 'model_instance'] + allowed_file_types = ['checkpoint', 'model_instance'] file_type = file_type.lower() if file_type not in allowed_file_types: raise ValueError(f'file_type not in {allowed_file_types}') From 8d8ce6d8fcb6e400fe43a8726dba769f44ba6865 Mon Sep 17 00:00:00 2001 From: LEE EUI TAE Date: Fri, 27 Mar 2026 21:43:14 +0900 Subject: [PATCH 7/7] clean --- sevenn/calculator.py | 10 +++++++---
sevenn/main/sevenn_get_model.py | 2 +- tests/lammps_tests/test_lammps.py | 17 ++++++++++++++--- 3 files changed, 22 insertions(+), 7 deletions(-) diff --git a/sevenn/calculator.py b/sevenn/calculator.py index 9bb70665..81ad325c 100644 --- a/sevenn/calculator.py +++ b/sevenn/calculator.py @@ -173,8 +173,9 @@ def __init__( if isinstance(self.model, AtomGraphSequential): force_output = self.model._modules.get('force_output') if force_output is not None: - self.atomic_virial_from_deploy = self.atomic_virial_from_deploy or bool( - getattr(force_output, 'use_atomic_virial', False) + self.atomic_virial_from_deploy = ( + self.atomic_virial_from_deploy + or bool(getattr(force_output, 'use_atomic_virial', False)) ) self.modal = None @@ -264,7 +265,10 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes): del data['data_info'] elif self.atomic_virial_requested: force_output = self.model._modules.get('force_output') - if force_output is not None and hasattr(force_output, 'use_atomic_virial'): + if ( + force_output is not None + and hasattr(force_output, 'use_atomic_virial') + ): setattr(force_output, 'use_atomic_virial', True) output = self.model(data) diff --git a/sevenn/main/sevenn_get_model.py b/sevenn/main/sevenn_get_model.py index d98eb21d..1785d18b 100644 --- a/sevenn/main/sevenn_get_model.py +++ b/sevenn/main/sevenn_get_model.py @@ -121,7 +121,7 @@ def run(args): from sevenn.scripts.deploy import deploy, deploy_parallel if get_serial: - deploy(checkpoint_path, output_prefix, modal, use_flash=use_flash, use_oeq=use_oeq) + deploy(checkpoint_path, output_prefix, modal, use_flash=use_flash, use_oeq=use_oeq) # noqa: E501 else: deploy_parallel(checkpoint_path, output_prefix, modal, use_flash=use_flash, use_oeq=use_oeq) # noqa: E501 else: diff --git a/tests/lammps_tests/test_lammps.py b/tests/lammps_tests/test_lammps.py index a74f4c8e..ec213084 100644 --- a/tests/lammps_tests/test_lammps.py +++ b/tests/lammps_tests/test_lammps.py @@ -467,7 +467,11 @@ 
def test_serial_stress( ['bulk', 'surface'], ) def test_serial_stress_oeq( - system, serial_potential_path_oeq, ref_7net0_stress_calculator, lammps_cmd, tmp_path + system, + serial_potential_path_oeq, + ref_7net0_stress_calculator, + lammps_cmd, + tmp_path ): atoms = get_system(system) atoms_lammps = serial_stress_lammps_run( @@ -487,7 +491,11 @@ def test_serial_stress_oeq( ['bulk', 'surface'], ) def test_serial_stress_flash( - system, serial_potential_path_flash, ref_7net0_stress_calculator, lammps_cmd, tmp_path + system, + serial_potential_path_flash, + ref_7net0_stress_calculator, + lammps_cmd, + tmp_path ): atoms = get_system(system) atoms_lammps = serial_stress_lammps_run( @@ -506,7 +514,10 @@ def test_serial_stress_flash( ['bulk', 'surface'], ) def test_modal_serial_stress( - system, serial_modal_potential_path, ref_modal_stress_calculator, lammps_cmd, tmp_path + system, + serial_modal_potential_path, + ref_modal_stress_calculator, + lammps_cmd, tmp_path ): atoms = get_system(system) atoms_lammps = serial_stress_lammps_run(