diff --git a/sevenn/_keys.py b/sevenn/_keys.py index 0c9af7b7..4b8d63d7 100644 --- a/sevenn/_keys.py +++ b/sevenn/_keys.py @@ -65,6 +65,7 @@ SCALED_FORCE: Final[str] = 'scaled_force' PRED_STRESS: Final[str] = 'inferred_stress' +PRED_ATOMIC_VIRIAL: Final[str] = 'inferred_atomic_virial' SCALED_STRESS: Final[str] = 'scaled_stress' # very general data property for AtomGraphData diff --git a/sevenn/calculator.py b/sevenn/calculator.py index 2f4a3d59..06c52163 100644 --- a/sevenn/calculator.py +++ b/sevenn/calculator.py @@ -40,6 +40,7 @@ def __init__( enable_cueq: Optional[bool] = False, enable_flash: Optional[bool] = False, enable_oeq: Optional[bool] = False, + atomic_virial: bool = False, sevennet_config: Optional[Dict] = None, # Not used in logic, just meta info **kwargs, ) -> None: @@ -69,9 +70,13 @@ def __init__( if True, use OpenEquivariance to accelerate inference. sevennet_config: dict | None, default=None Not used, but can be used to carry meta information of this calculator + atomic_virial: bool, default=False + If True, request per-atom virial output (`stresses`) at runtime. 
""" super().__init__(**kwargs) self.sevennet_config = None + self.atomic_virial_requested = atomic_virial + self.atomic_virial_from_deploy = False if isinstance(model, pathlib.PurePath): model = str(model) @@ -131,6 +136,7 @@ def __init__( 'version': b'', 'dtype': b'', 'time': b'', + 'atomic_virial': b'', } model_loaded = torch.jit.load( model, _extra_files=extra_dict, map_location=self.device @@ -141,6 +147,9 @@ def __init__( sym_to_num[sym]: i for i, sym in enumerate(chem_symbols.split()) } self.cutoff = float(extra_dict['cutoff'].decode('utf-8')) + self.atomic_virial_from_deploy = ( + extra_dict['atomic_virial'].decode('utf-8') == 'yes' + ) elif isinstance(model, AtomGraphSequential): if model.type_map is None: @@ -161,6 +170,12 @@ def __init__( self.sevennet_config = sevennet_config self.model = model_loaded + if isinstance(self.model, AtomGraphSequential): + force_output = self.model._modules.get('force_output') + if force_output is not None: + self.atomic_virial_from_deploy = self.atomic_virial_from_deploy or bool( + getattr(force_output, 'use_atomic_virial', False) + ) self.modal = None if isinstance(self.model, AtomGraphSequential): @@ -182,6 +197,7 @@ def __init__( 'energy', 'forces', 'stress', + 'stresses', 'energies', ] @@ -206,8 +222,7 @@ def output_to_results(self, output: Dict[str, torch.Tensor]) -> Dict[str, Any]: .cpu() .numpy()[[0, 1, 2, 4, 5, 3]] # as voigt notation ) - # Store results - return { + results: Dict[str, Any] = { 'free_energy': energy, 'energy': energy, 'energies': atomic_energies, @@ -215,6 +230,12 @@ def output_to_results(self, output: Dict[str, torch.Tensor]) -> Dict[str, Any]: 'stress': stress, 'num_edges': output[KEY.EDGE_IDX].shape[1], } + if KEY.PRED_ATOMIC_VIRIAL in output: + virial = ( + output[KEY.PRED_ATOMIC_VIRIAL].detach().cpu().numpy()[:num_atoms, :] + ) + results['stresses'] = virial + return results def calculate(self, atoms=None, properties=None, system_changes=all_changes): is_ts_type = isinstance(self.model, 
torch_script_type) @@ -241,8 +262,13 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes): data[KEY.EDGE_VEC].requires_grad_(True) # backward compatibility data = data.to_dict() del data['data_info'] + elif self.atomic_virial_requested: + force_output = self.model._modules.get('force_output') + if force_output is not None and hasattr(force_output, 'use_atomic_virial'): + setattr(force_output, 'use_atomic_virial', True) - self.results = self.output_to_results(self.model(data)) + output = self.model(data) + self.results = self.output_to_results(output) class SevenNetD3Calculator(SumCalculator): diff --git a/sevenn/main/sevenn_get_model.py b/sevenn/main/sevenn_get_model.py index 1785d18b..ac3193fd 100644 --- a/sevenn/main/sevenn_get_model.py +++ b/sevenn/main/sevenn_get_model.py @@ -64,6 +64,15 @@ def add_args(parser): help='Use LAMMPS ML-IAP interface.', action='store_true', ) + ag.add_argument( + '--atomic_virial', + help=( + 'Serial deploy only: append per-atom virial output ' + '(inferred_atomic_virial) to TorchScript. This marks model-side ' + 'atomic virial capability for downstream calculator usage.' 
+ ), + action='store_true', + ) def run(args): @@ -78,6 +87,7 @@ def run(args): use_cueq = args.enable_cueq use_oeq = args.enable_oeq use_mliap = args.use_mliap + atomic_virial = args.atomic_virial # Check dependencies if use_flash: @@ -104,9 +114,14 @@ def run(args): if use_mliap and get_parallel: raise ValueError('Currently, ML-IAP interface does not tested on parallel.') + if atomic_virial and (use_mliap or get_parallel): + raise ValueError('--atomic_virial is only supported for serial deployment.') + # deploy if output_prefix is None: output_prefix = 'deployed_parallel' if not get_serial else 'deployed_serial' + if atomic_virial: + output_prefix = 'deployed_model' if use_mliap: output_prefix += '_mliap' @@ -118,10 +133,24 @@ def run(args): checkpoint_path = sevenn.util.pretrained_name_to_path(checkpoint) if not use_mliap: - from sevenn.scripts.deploy import deploy, deploy_parallel - - if get_serial: - deploy(checkpoint_path, output_prefix, modal, use_flash=use_flash, use_oeq=use_oeq) # noqa: E501 + from sevenn.scripts.deploy import deploy, deploy_parallel, deploy_ts + + if atomic_virial: + deploy_ts( + checkpoint_path, + output_prefix, + modal, + use_flash=use_flash, + use_oeq=use_oeq, + atomic_virial=atomic_virial, + ) + elif get_serial: + deploy( + checkpoint_path, + output_prefix, + modal, + use_flash=use_flash, + use_oeq=use_oeq) else: deploy_parallel(checkpoint_path, output_prefix, modal, use_flash=use_flash, use_oeq=use_oeq) # noqa: E501 else: diff --git a/sevenn/main/sevenn_patch_lammps.py b/sevenn/main/sevenn_patch_lammps.py index e90b1c78..3ace2586 100644 --- a/sevenn/main/sevenn_patch_lammps.py +++ b/sevenn/main/sevenn_patch_lammps.py @@ -36,6 +36,11 @@ def add_args(parser): help='Enable OpenEquivariance', action='store_true', ) + ag.add_argument( + '--atomic_stress', + help='Patch pair_e3gnn with atomic-stress enabled source files.', + action='store_true', + ) # cxx_standard is detected automatically @@ -54,6 +59,12 @@ def run(args): d3_support = '0' 
print(' - D3 support disabled') + atomic_stress = '1' if args.atomic_stress else '0' + if args.atomic_stress: + print(' - Atomic stress patch enabled') + else: + print(' - Atomic stress patch disabled') + so_oeq = '' if args.enable_oeq: try: @@ -111,6 +122,10 @@ def run(args): if args.enable_oeq: assert osp.isfile(so_oeq) cmd += f' {so_oeq}' + else: + cmd += ' NONE' + + cmd += f' {atomic_stress}' res = subprocess.run(cmd.split()) return res.returncode # is it meaningless? diff --git a/sevenn/nn/force_output.py b/sevenn/nn/force_output.py index cd1cc816..0ee2073b 100644 --- a/sevenn/nn/force_output.py +++ b/sevenn/nn/force_output.py @@ -149,7 +149,9 @@ def __init__( data_key_energy: str = KEY.PRED_TOTAL_ENERGY, data_key_force: str = KEY.PRED_FORCE, data_key_stress: str = KEY.PRED_STRESS, + data_key_atomic_virial: str = KEY.PRED_ATOMIC_VIRIAL, data_key_cell_volume: str = KEY.CELL_VOLUME, + use_atomic_virial: bool = False, ) -> None: super().__init__() @@ -158,7 +160,9 @@ def __init__( self.key_energy = data_key_energy self.key_force = data_key_force self.key_stress = data_key_stress + self.key_atomic_virial = data_key_atomic_virial self.key_cell_volume = data_key_cell_volume + self.use_atomic_virial = use_atomic_virial self._is_batch_data = True def get_grad_key(self) -> str: @@ -206,6 +210,8 @@ def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: _s = torch.zeros(tot_num, 6, dtype=fij.dtype, device=fij.device) _edge_dst6 = broadcast(edge_idx[1], _virial, 0) _s.scatter_reduce_(0, _edge_dst6, _virial, reduce='sum') + if self.use_atomic_virial: + data[self.key_atomic_virial] = torch.neg(_s) if self._is_batch_data: batch = data[KEY.BATCH] # for deploy, must be defined first diff --git a/sevenn/pair_e3gnn/pair_e3gnn.cpp b/sevenn/pair_e3gnn/pair_e3gnn.cpp index 3e99b80b..94942c6a 100644 --- a/sevenn/pair_e3gnn/pair_e3gnn.cpp +++ b/sevenn/pair_e3gnn/pair_e3gnn.cpp @@ -67,7 +67,7 @@ PairE3GNN::~PairE3GNN() { memory->destroy(setflag); memory->destroy(cutsq); 
memory->destroy(map); - memory->destroy(elements); + // memory->destroy(elements); } } @@ -77,42 +77,32 @@ void PairE3GNN::compute(int eflag, int vflag) { This compute function is ispired/modified from stress branch of pair-nequip https://github.com/mir-group/pair_nequip */ - if (eflag || vflag) ev_setup(eflag, vflag); else evflag = vflag_fdotr = 0; - if (vflag_atom) { - error->all(FLERR, "atomic stress is not supported\n"); - } - - int nlocal = list->inum; // same as nlocal - int *ilist = list->ilist; - tagint *tag = atom->tag; - std::unordered_map tag_map; + // if (vflag_atom) { + // error->all(FLERR, "atomic stress is not supported\n"); + // } if (atom->tag_consecutive() == 0) { - for (int ii = 0; ii < nlocal; ii++) { - const int i = ilist[ii]; - int itag = tag[i]; - tag_map[itag] = ii+1; - // printf("MODIFY setting %i => %i \n",itag, tag_map[itag] ); - } - } else { - //Ordered which mappling required - for (int ii = 0; ii < nlocal; ii++) { - const int itag = ilist[ii]+1; - tag_map[itag] = ii+1; - // printf("normal setting %i => %i \n",itag, tag_map[itag] ); - } + error->all(FLERR, "Pair e3gnn requires consecutive atom IDs"); } double **x = atom->x; double **f = atom->f; int *type = atom->type; - long num_atoms[1] = {nlocal}; + int nlocal = list->inum; // same as nlocal, Why? is it different from atom->nlocal? 
+ int *ilist = list->ilist; + int inum = list->inum; - int tag2i[nlocal]; + // tag ignore PBC + tagint *tag = atom->tag; + + std::unordered_map tag_map; + std::vector graph_index_to_i(nlocal); + + long num_atoms[1] = {nlocal}; int *numneigh = list->numneigh; // j loop cond int **firstneigh = list->firstneigh; // j list @@ -125,78 +115,64 @@ void PairE3GNN::compute(int eflag, int vflag) { } const int nedges_upper_bound = bound; - float cell[3][3]; - cell[0][0] = domain->boxhi[0] - domain->boxlo[0]; - cell[0][1] = 0.0; - cell[0][2] = 0.0; - - cell[1][0] = domain->xy; - cell[1][1] = domain->boxhi[1] - domain->boxlo[1]; - cell[1][2] = 0.0; - - cell[2][0] = domain->xz; - cell[2][1] = domain->yz; - cell[2][2] = domain->boxhi[2] - domain->boxlo[2]; - - torch::Tensor inp_cell = torch::from_blob(cell, {3, 3}, FLOAT_TYPE); - torch::Tensor inp_num_atoms = torch::from_blob(num_atoms, {1}, INTEGER_TYPE); - - torch::Tensor inp_node_type = torch::zeros({nlocal}, INTEGER_TYPE); - torch::Tensor inp_pos = torch::zeros({nlocal, 3}); - - torch::Tensor inp_cell_volume = - torch::dot(inp_cell[0], torch::cross(inp_cell[1], inp_cell[2], 0)); - - float pbc_shift_tmp[nedges_upper_bound][3]; - - auto node_type = inp_node_type.accessor(); - auto pos = inp_pos.accessor(); + std::vector node_type; + float edge_vec[nedges_upper_bound][3]; long edge_idx_src[nedges_upper_bound]; long edge_idx_dst[nedges_upper_bound]; - int nedges = 0; - - for (int ii = 0; ii < nlocal; ii++) { + for (int ii = 0; ii < inum; ii++) { + // populate tag_map of local atoms const int i = ilist[ii]; - int itag = tag_map[tag[i]]; - tag2i[itag - 1] = i; + const int itag = tag[i]; const int itype = type[i]; - node_type[itag - 1] = map[itype]; - pos[itag - 1][0] = x[i][0]; - pos[itag - 1][1] = x[i][1]; - pos[itag - 1][2] = x[i][2]; + tag_map[itag] = ii; + graph_index_to_i[ii] = i; + node_type.push_back(map[itype]); } - for (int ii = 0; ii < nlocal; ii++) { + int nedges = 0; + // loop over neighbors, build graph + for (int ii = 
0; ii < inum; ii++) { const int i = ilist[ii]; - int itag = tag_map[tag[i]]; + const int i_graph_idx = ii; const int *jlist = firstneigh[i]; const int jnum = numneigh[i]; for (int jj = 0; jj < jnum; jj++) { - int j = jlist[jj]; // atom over pbc is different atom - int jtag = tag_map[tag[j]]; // atom over pbs is same atom (it starts from 1) + int j = jlist[jj]; + const int jtag = tag[j]; j &= NEIGHMASK; - const int jtype = type[j]; + const auto found = tag_map.find(jtag); + if (found == tag_map.end()) continue; + const int j_graph_idx = found->second; + + // we have to calculate Rij to check cutoff in lammps side const double delij[3] = {x[j][0] - x[i][0], x[j][1] - x[i][1], x[j][2] - x[i][2]}; const double Rij = delij[0] * delij[0] + delij[1] * delij[1] + delij[2] * delij[2]; - if (Rij < cutoff_square) { - edge_idx_src[nedges] = itag - 1; - edge_idx_dst[nedges] = jtag - 1; - - pbc_shift_tmp[nedges][0] = x[j][0] - pos[jtag - 1][0]; - pbc_shift_tmp[nedges][1] = x[j][1] - pos[jtag - 1][1]; - pbc_shift_tmp[nedges][2] = x[j][2] - pos[jtag - 1][2]; + if (Rij < cutoff_square) { + // if given j is not inside cutoff + if (nedges >= nedges_upper_bound) { + error->all(FLERR, "nedges exceeded nedges_upper_bound"); + } + edge_idx_src[nedges] = i_graph_idx; + edge_idx_dst[nedges] = j_graph_idx; + edge_vec[nedges][0] = delij[0]; + edge_vec[nedges][1] = delij[1]; + edge_vec[nedges][2] = delij[2]; nedges++; } } // j loop end } // i loop end + // convert data to Tensor + auto inp_node_type = torch::from_blob(node_type.data(), nlocal, INTEGER_TYPE); + auto inp_num_atoms = torch::from_blob(num_atoms, {1}, INTEGER_TYPE); + auto edge_idx_src_tensor = torch::from_blob(edge_idx_src, {nedges}, INTEGER_TYPE); auto edge_idx_dst_tensor = @@ -204,66 +180,102 @@ void PairE3GNN::compute(int eflag, int vflag) { auto inp_edge_index = torch::stack({edge_idx_src_tensor, edge_idx_dst_tensor}); - // r' = r + {shift_tensor(integer vector of len 3)} @ cell_tensor - // shift_tensor = (cell_tensor)^-1^T @ 
(r' - r) - torch::Tensor cell_inv_tensor = - inp_cell.inverse().transpose(0, 1).unsqueeze(0).to(device); - torch::Tensor pbc_shift_tmp_tensor = - torch::from_blob(pbc_shift_tmp, {nedges, 3}, FLOAT_TYPE) - .view({nedges, 3, 1}) - .to(device); - torch::Tensor inp_cell_shift = - torch::bmm(cell_inv_tensor.expand({nedges, 3, 3}), pbc_shift_tmp_tensor) - .view({nedges, 3}); - - inp_pos.set_requires_grad(true); - - c10::Dict input_dict; + auto inp_edge_vec = torch::from_blob(edge_vec, {nedges, 3}, FLOAT_TYPE); + if (print_info) { + std::cout << " Nlocal: " << nlocal << std::endl; + std::cout << " Nedges: " << nedges << "\n" << std::endl; + } + + auto edge_vec_device = inp_edge_vec.to(device); + edge_vec_device.set_requires_grad(true); + + torch::Dict input_dict; input_dict.insert("x", inp_node_type.to(device)); - input_dict.insert("pos", inp_pos.to(device)); input_dict.insert("edge_index", inp_edge_index.to(device)); + input_dict.insert("edge_vec", edge_vec_device); input_dict.insert("num_atoms", inp_num_atoms.to(device)); - input_dict.insert("cell_lattice_vectors", inp_cell.to(device)); - input_dict.insert("cell_volume", inp_cell_volume.to(device)); - input_dict.insert("pbc_shift", inp_cell_shift); + input_dict.insert("nlocal", inp_num_atoms.to(torch::kCPU)); std::vector input(1, input_dict); auto output = model.forward(input).toGenericDict(); - torch::Tensor total_energy_tensor = - output.at("inferred_total_energy").toTensor().cpu(); - torch::Tensor force_tensor = output.at("inferred_force").toTensor().cpu(); + torch::Tensor energy_tensor = + output.at("inferred_total_energy").toTensor().squeeze(); + + // dE_dr + auto grads = torch::autograd::grad({energy_tensor}, {edge_vec_device}); + torch::Tensor dE_dr = grads[0].to(torch::kCPU); + + eng_vdwl += energy_tensor.detach().to(torch::kCPU).item(); + torch::Tensor force_tensor = torch::zeros({nlocal, 3}); + + auto _edge_idx_src_tensor = + edge_idx_src_tensor.repeat_interleave(3).view({nedges, 3}); + auto 
_edge_idx_dst_tensor = + edge_idx_dst_tensor.repeat_interleave(3).view({nedges, 3}); + + force_tensor.scatter_reduce_(0, _edge_idx_src_tensor, dE_dr, "sum"); + force_tensor.scatter_reduce_(0, _edge_idx_dst_tensor, torch::neg(dE_dr), + "sum"); + auto forces = force_tensor.accessor(); - eng_vdwl += total_energy_tensor.item(); - for (int itag = 0; itag < nlocal; itag++) { - int i = tag2i[itag]; - f[i][0] += forces[itag][0]; - f[i][1] += forces[itag][1]; - f[i][2] += forces[itag][2]; + for (int graph_idx = 0; graph_idx < nlocal; graph_idx++) { + int i = graph_index_to_i[graph_idx]; + f[i][0] += forces[graph_idx][0]; + f[i][1] += forces[graph_idx][1]; + f[i][2] += forces[graph_idx][2]; } + // Virial stress from edge contributions if (vflag) { - // more accurately, it is virial part of stress - torch::Tensor stress_tensor = output.at("inferred_stress").toTensor().cpu(); - auto virial_stress_tensor = stress_tensor * inp_cell_volume; - // xy yz zx order in vasp (voigt is xx yy zz yz xz xy) + auto diag = inp_edge_vec * dE_dr; + auto s12 = inp_edge_vec.select(1, 0) * dE_dr.select(1, 1); + auto s23 = inp_edge_vec.select(1, 1) * dE_dr.select(1, 2); + auto s31 = inp_edge_vec.select(1, 2) * dE_dr.select(1, 0); + std::vector voigt_list = { + diag, s12.unsqueeze(-1), s23.unsqueeze(-1), s31.unsqueeze(-1)}; + auto voigt = torch::cat(voigt_list, 1); + + torch::Tensor per_atom_stress_tensor = torch::zeros({nlocal, 6}); + auto _edge_idx_dst6_tensor = + edge_idx_dst_tensor.repeat_interleave(6).view({nedges, 6}); + per_atom_stress_tensor.scatter_reduce_(0, _edge_idx_dst6_tensor, voigt, + "sum"); + + auto virial_stress_tensor = + torch::neg(torch::sum(per_atom_stress_tensor, 0)); auto virial_stress = virial_stress_tensor.accessor(); + virial[0] += virial_stress[0]; virial[1] += virial_stress[1]; virial[2] += virial_stress[2]; virial[3] += virial_stress[3]; virial[4] += virial_stress[5]; virial[5] += virial_stress[4]; + + if (vflag_atom) { + auto per_atom_stress = 
per_atom_stress_tensor.accessor(); + + for (int gi = 0; gi < nlocal; gi++) { + const int i = graph_index_to_i[gi]; + vatom[i][0] += -per_atom_stress[gi][0]; + vatom[i][1] += -per_atom_stress[gi][1]; + vatom[i][2] += -per_atom_stress[gi][2]; + vatom[i][3] += -per_atom_stress[gi][3]; + vatom[i][4] += -per_atom_stress[gi][5]; + vatom[i][5] += -per_atom_stress[gi][4]; + } + } } if (eflag_atom) { torch::Tensor atomic_energy_tensor = - output.at("atomic_energy").toTensor().cpu().view({nlocal}); + output.at("atomic_energy").toTensor().to(torch::kCPU).view({nlocal}); auto atomic_energy = atomic_energy_tensor.accessor(); - for (int itag = 0; itag < nlocal; itag++) { - int i = tag2i[itag]; - eatom[i] += atomic_energy[itag]; + for (int gi = 0; gi < nlocal; gi++) { + const int i = graph_index_to_i[gi]; + eatom[i] += atomic_energy[gi]; } } diff --git a/sevenn/pair_e3gnn/pair_e3gnn_atomic_stress.cpp b/sevenn/pair_e3gnn/pair_e3gnn_atomic_stress.cpp new file mode 100644 index 00000000..ea161337 --- /dev/null +++ b/sevenn/pair_e3gnn/pair_e3gnn_atomic_stress.cpp @@ -0,0 +1,444 @@ +/* ---------------------------------------------------------------------- + LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator + https://lammps.sandia.gov/, Sandia National Laboratories + Steve Plimpton, sjplimp@sandia.gov + + Copyright (2003) Sandia Corporation. Under the terms of Contract + DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains + certain rights in this software. This software is distributed under + the GNU General Public License. + + See the README file in the top-level LAMMPS directory. 
+------------------------------------------------------------------------- */ + +/* ---------------------------------------------------------------------- + Contributing author: Yutack Park (SNU) +------------------------------------------------------------------------- */ + +#include +#include +#include +#include + +#include +#include + +#include "atom.h" +#include "domain.h" +#include "error.h" +#include "force.h" +#include "memory.h" +#include "neigh_list.h" +#include "neigh_request.h" +#include "neighbor.h" + +#include "pair_e3gnn.h" + +using namespace LAMMPS_NS; + +// Undefined reference; body in pair_e3gnn_oeq_autograd.cpp to be linked +extern void pair_e3gnn_oeq_register_autograd(); + +#define INTEGER_TYPE torch::TensorOptions().dtype(torch::kInt64) +#define FLOAT_TYPE torch::TensorOptions().dtype(torch::kFloat) + +PairE3GNN::PairE3GNN(LAMMPS *lmp) : Pair(lmp) { + // constructor + // Virial is accumulated directly from the model; skip LAMMPS fdotr post-pass. + no_virial_fdotr_compute = 1; + + const char *print_flag = std::getenv("SEVENN_PRINT_INFO"); + if (print_flag) + print_info = true; + + std::string device_name; + if (torch::cuda::is_available()) { + device = torch::kCUDA; + device_name = "CUDA"; + } else { + device = torch::kCPU; + device_name = "CPU"; + } + + if (lmp->logfile) { + fprintf(lmp->logfile, "PairE3GNN using device : %s\n", device_name.c_str()); + } +} + +PairE3GNN::~PairE3GNN() { + if (allocated) { + memory->destroy(setflag); + memory->destroy(cutsq); + memory->destroy(map); + memory->destroy(elements); + } +} + +void PairE3GNN::compute(int eflag, int vflag) { + // compute + /* + This compute function is ispired/modified from stress branch of pair-nequip + https://github.com/mir-group/pair_nequip + */ + + if (eflag || vflag) + ev_setup(eflag, vflag); + else + evflag = vflag_fdotr = 0; +// if (vflag_atom) { +// error->all(FLERR, "atomic stress is not supported\n"); +// } + + int nlocal = list->inum; // same as nlocal + int *ilist = 
list->ilist; + tagint *tag = atom->tag; + std::unordered_map tag_map; + + if (atom->tag_consecutive() == 0) { + for (int ii = 0; ii < nlocal; ii++) { + const int i = ilist[ii]; + int itag = tag[i]; + tag_map[itag] = ii+1; + // printf("MODIFY setting %i => %i \n",itag, tag_map[itag] ); + } + } else { + //Ordered which mappling required + for (int ii = 0; ii < nlocal; ii++) { + const int itag = ilist[ii]+1; + tag_map[itag] = ii+1; + // printf("normal setting %i => %i \n",itag, tag_map[itag] ); + } + } + + double **x = atom->x; + double **f = atom->f; + int *type = atom->type; + long num_atoms[1] = {nlocal}; + + int tag2i[nlocal]; + + int *numneigh = list->numneigh; // j loop cond + int **firstneigh = list->firstneigh; // j list + + int bound; + if (this->nedges_bound == -1) { + bound = std::accumulate(numneigh, numneigh + nlocal, 0); + } else { + bound = this->nedges_bound; + } + const int nedges_upper_bound = bound; + + float cell[3][3]; + cell[0][0] = domain->boxhi[0] - domain->boxlo[0]; + cell[0][1] = 0.0; + cell[0][2] = 0.0; + + cell[1][0] = domain->xy; + cell[1][1] = domain->boxhi[1] - domain->boxlo[1]; + cell[1][2] = 0.0; + + cell[2][0] = domain->xz; + cell[2][1] = domain->yz; + cell[2][2] = domain->boxhi[2] - domain->boxlo[2]; + + torch::Tensor inp_cell = torch::from_blob(cell, {3, 3}, FLOAT_TYPE); + torch::Tensor inp_num_atoms = torch::from_blob(num_atoms, {1}, INTEGER_TYPE); + + torch::Tensor inp_node_type = torch::zeros({nlocal}, INTEGER_TYPE); + torch::Tensor inp_pos = torch::zeros({nlocal, 3}); + + torch::Tensor inp_cell_volume = + torch::dot(inp_cell[0], torch::cross(inp_cell[1], inp_cell[2], 0)); + + float pbc_shift_tmp[nedges_upper_bound][3]; + + auto node_type = inp_node_type.accessor(); + auto pos = inp_pos.accessor(); + + long edge_idx_src[nedges_upper_bound]; + long edge_idx_dst[nedges_upper_bound]; + + int nedges = 0; + + for (int ii = 0; ii < nlocal; ii++) { + const int i = ilist[ii]; + int itag = tag_map[tag[i]]; + tag2i[itag - 1] = i; + const 
int itype = type[i]; + node_type[itag - 1] = map[itype]; + pos[itag - 1][0] = x[i][0]; + pos[itag - 1][1] = x[i][1]; + pos[itag - 1][2] = x[i][2]; + } + + for (int ii = 0; ii < nlocal; ii++) { + const int i = ilist[ii]; + int itag = tag_map[tag[i]]; + const int *jlist = firstneigh[i]; + const int jnum = numneigh[i]; + + for (int jj = 0; jj < jnum; jj++) { + int j = jlist[jj]; // atom over pbc is different atom + int jtag = tag_map[tag[j]]; // atom over pbs is same atom (it starts from 1) + j &= NEIGHMASK; + const int jtype = type[j]; + + const double delij[3] = {x[j][0] - x[i][0], x[j][1] - x[i][1], + x[j][2] - x[i][2]}; + const double Rij = + delij[0] * delij[0] + delij[1] * delij[1] + delij[2] * delij[2]; + if (Rij < cutoff_square) { + edge_idx_src[nedges] = itag - 1; + edge_idx_dst[nedges] = jtag - 1; + + pbc_shift_tmp[nedges][0] = x[j][0] - pos[jtag - 1][0]; + pbc_shift_tmp[nedges][1] = x[j][1] - pos[jtag - 1][1]; + pbc_shift_tmp[nedges][2] = x[j][2] - pos[jtag - 1][2]; + + nedges++; + } + } // j loop end + } // i loop end + + auto edge_idx_src_tensor = + torch::from_blob(edge_idx_src, {nedges}, INTEGER_TYPE); + auto edge_idx_dst_tensor = + torch::from_blob(edge_idx_dst, {nedges}, INTEGER_TYPE); + auto inp_edge_index = + torch::stack({edge_idx_src_tensor, edge_idx_dst_tensor}); + + // r' = r + {shift_tensor(integer vector of len 3)} @ cell_tensor + // shift_tensor = (cell_tensor)^-1^T @ (r' - r) + torch::Tensor cell_inv_tensor = + inp_cell.inverse().transpose(0, 1).unsqueeze(0).to(device); + torch::Tensor pbc_shift_tmp_tensor = + torch::from_blob(pbc_shift_tmp, {nedges, 3}, FLOAT_TYPE) + .view({nedges, 3, 1}) + .to(device); + torch::Tensor inp_cell_shift = + torch::bmm(cell_inv_tensor.expand({nedges, 3, 3}), pbc_shift_tmp_tensor) + .view({nedges, 3}); + + inp_pos.set_requires_grad(true); + + c10::Dict input_dict; + input_dict.insert("x", inp_node_type.to(device)); + input_dict.insert("pos", inp_pos.to(device)); + input_dict.insert("edge_index", 
inp_edge_index.to(device)); + input_dict.insert("num_atoms", inp_num_atoms.to(device)); + input_dict.insert("cell_lattice_vectors", inp_cell.to(device)); + input_dict.insert("cell_volume", inp_cell_volume.to(device)); + input_dict.insert("pbc_shift", inp_cell_shift); + + std::vector input(1, input_dict); + auto output = model.forward(input).toGenericDict(); + + torch::Tensor total_energy_tensor = + output.at("inferred_total_energy").toTensor().cpu(); + torch::Tensor force_tensor = output.at("inferred_force").toTensor().cpu(); + auto forces = force_tensor.accessor(); + eng_vdwl += total_energy_tensor.item(); + + for (int itag = 0; itag < nlocal; itag++) { + int i = tag2i[itag]; + f[i][0] += forces[itag][0]; + f[i][1] += forces[itag][1]; + f[i][2] += forces[itag][2]; + } + + if (vflag) { + // more accurately, it is virial part of stress + torch::Tensor stress_tensor = output.at("inferred_stress").toTensor().cpu(); + bool has_atomic_virial = false; + torch::Tensor stress_virial_tensor; + try { + stress_virial_tensor = output.at("inferred_atomic_virial").toTensor().cpu(); + has_atomic_virial = true; + } catch (...) 
{ + has_atomic_virial = false; + } + auto virial_stress_tensor = stress_tensor * inp_cell_volume; + if (vflag_atom && !has_atomic_virial) { + error->warning( + FLERR, + "Model output has no inferred_atomic_virial; " + "stress/atom consistency test cannot be completed"); + } + // xy yz zx order in vasp (voigt is xx yy zz yz xz xy) + auto virial_stress = virial_stress_tensor.accessor(); + virial[0] += virial_stress[0]; + virial[1] += virial_stress[1]; + virial[2] += virial_stress[2]; + virial[3] += virial_stress[3]; + virial[4] += virial_stress[5]; + virial[5] += virial_stress[4]; + + if (vflag_atom && has_atomic_virial) { + auto atomic_virial = stress_virial_tensor.accessor(); + for (int itag = 0; itag < nlocal; itag++) { + int i = tag2i[itag]; + vatom[i][0] += atomic_virial[itag][0]; + vatom[i][1] += atomic_virial[itag][1]; + vatom[i][2] += atomic_virial[itag][2]; + vatom[i][3] += atomic_virial[itag][3]; + vatom[i][4] += atomic_virial[itag][5]; + vatom[i][5] += atomic_virial[itag][4]; + } + } + } + + if (eflag_atom) { + torch::Tensor atomic_energy_tensor = + output.at("atomic_energy").toTensor().cpu().view({nlocal}); + auto atomic_energy = atomic_energy_tensor.accessor(); + for (int itag = 0; itag < nlocal; itag++) { + int i = tag2i[itag]; + eatom[i] += atomic_energy[itag]; + } + } + + // if it was the first MD step + if (this->nedges_bound == -1) { + this->nedges_bound = nedges * 1.2; + } // else if the nedges is too small, increase the bound + else if (nedges > this->nedges_bound / 1.2) { + this->nedges_bound = nedges * 1.2; + } +} + +// allocate arrays (called from coeff) +void PairE3GNN::allocate() { + allocated = 1; + int n = atom->ntypes; + + memory->create(setflag, n + 1, n + 1, "pair:setflag"); + memory->create(cutsq, n + 1, n + 1, "pair:cutsq"); + memory->create(map, n + 1, "pair:map"); +} + +// global settings for pair_style +void PairE3GNN::settings(int narg, char **arg) { + if (narg != 0) { + error->all(FLERR, "Illegal pair_style command"); + } +} + 
+void PairE3GNN::coeff(int narg, char **arg) { + + if (allocated) { + error->all(FLERR, "pair_e3gnn coeff called twice"); + } + allocate(); + + if (strcmp(arg[0], "*") != 0 || strcmp(arg[1], "*") != 0) { + error->all(FLERR, + "e3gnn: first and second input of pair_coeff should be '*'"); + } + // expected input : pair_coeff * * pot.pth type_name1 type_name2 ... + + std::unordered_map meta_dict = { + {"chemical_symbols_to_index", ""}, + {"cutoff", ""}, + {"num_species", ""}, + {"model_type", ""}, + {"version", ""}, + {"dtype", ""}, + {"flashTP", "version mismatch"}, + {"oeq", "version mismatch"}, + {"time", ""}}; + + // model loading from input + try { + model = torch::jit::load(std::string(arg[2]), device, meta_dict); + } catch (const c10::Error &e) { + error->all(FLERR, "error loading the model, check the path of the model"); + } + // model = torch::jit::freeze(model); model is already freezed + + torch::jit::setGraphExecutorOptimize(false); + torch::jit::FusionStrategy strategy; + // thing about dynamic recompile as tensor shape varies, this is default + // strategy = {{torch::jit::FusionBehavior::DYNAMIC, 3}}; + strategy = {{torch::jit::FusionBehavior::STATIC, 0}}; + torch::jit::setFusionStrategy(strategy); + + cutoff = std::stod(meta_dict["cutoff"]); + cutoff_square = cutoff * cutoff; + + // to make torch::autograd::grad() works + if (meta_dict["oeq"] == "yes") { + pair_e3gnn_oeq_register_autograd(); + } + + if (meta_dict["model_type"].compare("E3_equivariant_model") != 0) { + error->all(FLERR, "given model type is not E3_equivariant_model"); + } + + std::string chem_str = meta_dict["chemical_symbols_to_index"]; + int ntypes = atom->ntypes; + + auto delim = " "; + char *tok = std::strtok(const_cast(chem_str.c_str()), delim); + std::vector chem_vec; + while (tok != nullptr) { + chem_vec.push_back(std::string(tok)); + tok = std::strtok(nullptr, delim); + } + + bool found_flag = false; + for (int i = 3; i < narg; i++) { + found_flag = false; + for (int j = 0; j < 
chem_vec.size(); j++) { + if (chem_vec[j].compare(arg[i]) == 0) { + map[i - 2] = j; + found_flag = true; + fprintf(lmp->logfile, "Chemical specie '%s' is assigned to type %d\n", + arg[i], i - 2); + break; + } + } + if (!found_flag) { + error->all(FLERR, "Unknown chemical specie is given"); + } + } + + if (ntypes > narg - 3) { + error->all(FLERR, "Not enough chemical specie is given. Check pair_coeff " + "and types in your data/script"); + } + + for (int i = 1; i <= ntypes; i++) { + for (int j = 1; j <= ntypes; j++) { + if ((map[i] >= 0) && (map[j] >= 0)) { + setflag[i][j] = 1; + cutsq[i][j] = cutoff * cutoff; + } + } + } + + if (lmp->logfile) { + fprintf(lmp->logfile, "from sevenn version '%s' ", + meta_dict["version"].c_str()); + fprintf(lmp->logfile, "%s precision model, deployed: %s\n", + meta_dict["dtype"].c_str(), meta_dict["time"].c_str()); + fprintf(lmp->logfile, "FlashTP: %s\n", + meta_dict["flashTP"].c_str()); + fprintf(lmp->logfile, "OEQ: %s\n", + meta_dict["oeq"].c_str()); + } +} + +// init specific to this pair +void PairE3GNN::init_style() { + // Newton flag is irrelevant if use only one processor for simulation + /* + if (force->newton_pair == 0) { + error->all(FLERR, "Pair style nn requires newton pair on"); + } + */ + + // full neighbor list (this is many-body potential) + neighbor->add_request(this, NeighConst::REQ_FULL); +} + +double PairE3GNN::init_one(int i, int j) { return cutoff; } diff --git a/sevenn/pair_e3gnn/pair_e3gnn_parallel_atomic_stress.cpp b/sevenn/pair_e3gnn/pair_e3gnn_parallel_atomic_stress.cpp new file mode 100644 index 00000000..2ec93909 --- /dev/null +++ b/sevenn/pair_e3gnn/pair_e3gnn_parallel_atomic_stress.cpp @@ -0,0 +1,913 @@ +/* ---------------------------------------------------------------------- + LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator + https://lammps.sandia.gov/, Sandia National Laboratories + Steve Plimpton, sjplimp@sandia.gov + + Copyright (2003) Sandia Corporation. 
Under the terms of Contract + DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains + certain rights in this software. This software is distributed under + the GNU General Public License. + + See the README file in the top-level LAMMPS directory. +------------------------------------------------------------------------- */ + +/* ---------------------------------------------------------------------- + Contributing author: Yutack Park (SNU) +------------------------------------------------------------------------- */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include + +#include "atom.h" +#include "comm.h" +#include "comm_brick.h" +#include "error.h" +#include "force.h" +#include "memory.h" +#include "neigh_list.h" +#include "neighbor.h" +// #include "nvToolsExt.h" + +#include "pair_e3gnn_parallel.h" +#include + +#ifdef OMPI_MPI_H +#include "mpi-ext.h" //This should be included after mpi.h which is included in pair.h +#endif + +using namespace LAMMPS_NS; + +// Undefined reference; body in pair_e3gnn_oeq_autograd.cpp to be linked +extern void pair_e3gnn_oeq_register_autograd(); + +#define INTEGER_TYPE torch::TensorOptions().dtype(torch::kInt64) +#define FLOAT_TYPE torch::TensorOptions().dtype(torch::kFloat) + +DeviceBuffManager &DeviceBuffManager::getInstance() { + static DeviceBuffManager instance; + return instance; +} + +void DeviceBuffManager::get_buffer(int send_size, int recv_size, + float *&buf_send_ptr, float *&buf_recv_ptr) { + if (send_size > send_buf_size) { + cudaFree(buf_send_device); + cudaError_t cuda_err = + cudaMalloc(&buf_send_device, send_size * sizeof(float)); + send_buf_size = send_size; + } + if (recv_size > recv_buf_size) { + cudaFree(buf_recv_device); + cudaError_t cuda_err = + cudaMalloc(&buf_recv_device, recv_size * sizeof(float)); + recv_buf_size = recv_size; + } + buf_send_ptr = buf_send_device; + buf_recv_ptr = buf_recv_device; +} + 
+DeviceBuffManager::~DeviceBuffManager() { + cudaFree(buf_send_device); + cudaFree(buf_recv_device); +} + +PairE3GNNParallel::PairE3GNNParallel(LAMMPS *lmp) : Pair(lmp) { + // constructor + // Virial is accumulated directly from the model; skip LAMMPS fdotr post-pass. + no_virial_fdotr_compute = 1; + + const char *print_flag = std::getenv("SEVENN_PRINT_INFO"); + const char *print_both_flag = std::getenv("SEVENN_PRINT_BOTH_INFO"); + if (print_flag) { + world_rank = comm->me; + std::cout << "process rank: " << world_rank << " initialized" << std::endl; + print_info = (world_rank == 0) || print_both_flag; + } + + std::string device_name; + const bool use_gpu = torch::cuda::is_available(); + + comm_forward = 0; + comm_reverse = 0; + + // OpenMPI detection +#ifdef OMPI_MPI_H +#if defined(MPIX_CUDA_AWARE_SUPPORT) + if (1 == MPIX_Query_cuda_support()) { + use_cuda_mpi = true; + } else { + use_cuda_mpi = false; + } +#else + use_cuda_mpi = false; +#endif +#else + use_cuda_mpi = false; +#endif + // use_cuda_mpi = use_gpu && use_cuda_mpi; + // if (use_cuda_mpi) { + if (use_gpu) { + device = get_cuda_device(); + device_name = "CUDA"; + } else { + device = torch::kCPU; + device_name = "CPU"; + } + + if (std::getenv("OFF_E3GNN_PARALLEL_CUDA_MPI")) { + use_cuda_mpi = false; + } + + if (lmp->screen) { + if (use_gpu && !use_cuda_mpi) { + device_comm = torch::kCPU; + fprintf(lmp->screen, + "cuda-aware mpi not found, communicate via host device\n"); + } else { + device_comm = device; + } + fprintf(lmp->screen, "PairE3GNNParallel using device : %s\n", + device_name.c_str()); + fprintf(lmp->screen, "PairE3GNNParallel cuda-aware mpi: %s\n", + use_cuda_mpi ? 
"True" : "False"); + } + if (lmp->logfile) { + if (use_gpu && !use_cuda_mpi) { + device_comm = torch::kCPU; + fprintf(lmp->logfile, + "cuda-aware mpi not found, communicate via host device\n"); + } else { + device_comm = device; + } + fprintf(lmp->logfile, "PairE3GNNParallel using device : %s\n", + device_name.c_str()); + fprintf(lmp->logfile, "PairE3GNNParallel cuda-aware mpi: %s\n", + use_cuda_mpi ? "True" : "False"); + } +} + +torch::Device PairE3GNNParallel::get_cuda_device() { + char *cuda_visible = std::getenv("CUDA_VISIBLE_DEVICES"); + int num_gpus; + int idx; + int rank = comm->me; + num_gpus = torch::cuda::device_count(); + idx = rank % num_gpus; + if (print_info) + std::cout << world_rank << " Available # of GPUs found: " << num_gpus + << std::endl; + cudaError_t cuda_err = cudaSetDevice(idx); + if (cuda_err != cudaSuccess) { + std::cerr << "E3GNN: Failed to set CUDA device: " + << cudaGetErrorString(cuda_err) << std::endl; + } + return torch::Device(torch::kCUDA, idx); +} + +PairE3GNNParallel::~PairE3GNNParallel() { + if (allocated) { + memory->destroy(setflag); + memory->destroy(cutsq); + memory->destroy(map); + } +} + +int PairE3GNNParallel::get_x_dim() { return x_dim; } + +bool PairE3GNNParallel::use_cuda_mpi_() { return use_cuda_mpi; } + +bool PairE3GNNParallel::is_comm_preprocess_done() { + return comm_preprocess_done; +} + +void PairE3GNNParallel::compute(int eflag, int vflag) { + /* + Graph build on cpu + */ + if (eflag || vflag) + ev_setup(eflag, vflag); + else + evflag = vflag_fdotr = 0; +// if (vflag_atom) { +// error->all(FLERR, "atomic stress is not supported\n"); +// } + + if (atom->tag_consecutive() == 0) { + error->all(FLERR, "Pair e3gnn requires consecutive atom IDs"); + } + + double **x = atom->x; + double **f = atom->f; + int *type = atom->type; + int nlocal = list->inum; // same as nlocal + int nghost = atom->nghost; + int ntotal = nlocal + nghost; + int *ilist = list->ilist; + int inum = list->inum; + + CommBrick *comm_brick = 
dynamic_cast(comm); + if (comm_brick == nullptr) { + error->all(FLERR, "e3gnn/parallel: comm style should be brick & from " + "modified code of comm_brick"); + } + + bigint natoms = atom->natoms; + + // tag ignore PBC + tagint *tag = atom->tag; + + // store graph_idx from local to known ghost atoms(ghost atoms inside cutoff) + int tag_to_graph_idx[natoms + 1]; // tag starts from 1 not 0 + std::fill_n(tag_to_graph_idx, natoms + 1, -1); + + // to access tag_to_graph_idx from comm + tag_to_graph_idx_ptr = tag_to_graph_idx; + + int graph_indexer = nlocal; + int graph_index_to_i[ntotal]; + + int *numneigh = list->numneigh; // j loop cond + int **firstneigh = list->firstneigh; // j list + const int nedges_upper_bound = + std::accumulate(numneigh, numneigh + nlocal, 0); + + std::vector node_type; + std::vector node_type_ghost; + + float edge_vec[nedges_upper_bound][3]; + long edge_idx_src[nedges_upper_bound]; + long edge_idx_dst[nedges_upper_bound]; + + int nedges = 0; + for (int ii = 0; ii < inum; ii++) { + // populate tag_to_graph_idx of local atoms + const int i = ilist[ii]; + const int itag = tag[i]; + const int itype = type[i]; + tag_to_graph_idx[itag] = ii; + graph_index_to_i[ii] = i; + node_type.push_back(map[itype]); + } + + // loop over neighbors, build graph + for (int ii = 0; ii < inum; ii++) { + const int i = ilist[ii]; + const int i_graph_idx = ii; + const int *jlist = firstneigh[i]; + const int jnum = numneigh[i]; + + for (int jj = 0; jj < jnum; jj++) { + int j = jlist[jj]; + const int jtag = tag[j]; + j &= NEIGHMASK; + const int jtype = type[j]; + // we have to calculate Rij to check cutoff in lammps side + const double delij[3] = {x[j][0] - x[i][0], x[j][1] - x[i][1], + x[j][2] - x[i][2]}; + const double Rij = + delij[0] * delij[0] + delij[1] * delij[1] + delij[2] * delij[2]; + + int j_graph_idx; + if (Rij < cutoff_square) { + // if given j is not local atom and inside cutoff + if (tag_to_graph_idx[jtag] == -1) { + // if j is ghost atom inside cutoff but 
first seen + tag_to_graph_idx[jtag] = graph_indexer; + graph_index_to_i[graph_indexer] = j; + node_type_ghost.push_back(map[jtype]); + graph_indexer++; + } + + j_graph_idx = tag_to_graph_idx[jtag]; + edge_idx_src[nedges] = i_graph_idx; + edge_idx_dst[nedges] = j_graph_idx; + edge_vec[nedges][0] = delij[0]; + edge_vec[nedges][1] = delij[1]; + edge_vec[nedges][2] = delij[2]; + nedges++; + } + } // j loop end + } // i loop end + + // member variable + graph_size = graph_indexer; + const int ghost_node_num = graph_size - nlocal; + + // convert data to Tensor + auto inp_node_type = torch::from_blob(node_type.data(), nlocal, INTEGER_TYPE); + auto inp_node_type_ghost = + torch::from_blob(node_type_ghost.data(), ghost_node_num, INTEGER_TYPE); + + long num_nodes[1] = {long(nlocal)}; + auto inp_num_atoms = torch::from_blob(num_nodes, {1}, INTEGER_TYPE); + + auto edge_idx_src_tensor = + torch::from_blob(edge_idx_src, {nedges}, INTEGER_TYPE); + auto edge_idx_dst_tensor = + torch::from_blob(edge_idx_dst, {nedges}, INTEGER_TYPE); + auto inp_edge_index = + torch::stack({edge_idx_src_tensor, edge_idx_dst_tensor}); + + auto inp_edge_vec = torch::from_blob(edge_vec, {nedges, 3}, FLOAT_TYPE); + if (print_info) { + std::cout << world_rank << " Nlocal: " << nlocal << std::endl; + std::cout << world_rank << " Graph_size: " << graph_size << std::endl; + std::cout << world_rank << " Ghost_node_num: " << ghost_node_num + << std::endl; + std::cout << world_rank << " Nedges: " << nedges << "\n" << std::endl; + } + + // r_original requires grad True + inp_edge_vec.set_requires_grad(true); + + torch::Dict input_dict; + input_dict.insert("x", inp_node_type.to(device)); + input_dict.insert("x_ghost", inp_node_type_ghost.to(device)); + input_dict.insert("edge_index", inp_edge_index.to(device)); + input_dict.insert("edge_vec", inp_edge_vec.to(device)); + input_dict.insert("num_atoms", inp_num_atoms.to(device)); + input_dict.insert("nlocal", inp_num_atoms.to(torch::kCPU)); + + std::list> 
wrt_tensors; + wrt_tensors.push_back({input_dict.at("edge_vec")}); + + auto model_part = model_list.front(); + + auto output = model_part.forward({input_dict}).toGenericDict(); + + comm_preprocess(); + + // extra_graph_idx_map is set from comm_preprocess(); + // last one is for trash values. See pack_forward_init + const int extra_size = + ghost_node_num + static_cast(extra_graph_idx_map.size()) + 1; + torch::Tensor x_local; + torch::Tensor x_ghost; + + for (auto it = model_list.begin(); it != model_list.end(); ++it) { + if (it == model_list.begin()) + continue; + model_part = *it; + + x_local = output.at("x").toTensor().detach().to(device); + x_dim = x_local.size(1); // length of per atom vector(node feature) + + auto ghost_and_extra_x = torch::zeros({ghost_node_num + extra_size, x_dim}, + FLOAT_TYPE.device(device)); + x_comm = torch::cat({x_local, ghost_and_extra_x}, 0).to(device_comm); + comm_brick->forward_comm(this); // populate x_ghost by communication + + // What we got from forward_comm (node feature of ghosts) + x_ghost = torch::split_with_sizes( + x_comm, {nlocal, ghost_node_num, extra_size}, 0)[1]; + x_ghost.set_requires_grad(true); + + // prepare next input (output > next input) + output.insert_or_assign("x_ghost", x_ghost.to(device)); + // make another edge_vec to discriminate grad calculation with other + // edge_vecs(maybe redundant?) 
+ output.insert_or_assign("edge_vec", + output.at("edge_vec").toTensor().clone()); + + // save tensors for backprop + wrt_tensors.push_back({output.at("edge_vec").toTensor(), + output.at("x").toTensor(), + output.at("self_cont_tmp").toTensor(), + output.at("x_ghost").toTensor()}); + + output = model_part.forward({output}).toGenericDict(); + } + torch::Tensor energy_tensor = + output.at("inferred_total_energy").toTensor().squeeze(); + + torch::Tensor dE_dr = + torch::zeros({nedges, 3}, FLOAT_TYPE.device(device)); // create on device + torch::Tensor x_local_save; // holds grad info of x_local (it loses its grad + // when sends to CPU) + torch::Tensor self_conn_grads; + std::vector grads; + std::vector of_tensor; + + // TODO: most values of self_conn_grads were zero because we use only scalars + // for energy + for (auto rit = wrt_tensors.rbegin(); rit != wrt_tensors.rend(); ++rit) { + // edge_vec, x, x_ghost order + auto wrt_tensor = *rit; + if (rit == wrt_tensors.rbegin()) { + grads = torch::autograd::grad({energy_tensor}, wrt_tensor); + } else { + x_local_save.copy_(x_local); + // of wrt grads_output + grads = torch::autograd::grad(of_tensor, wrt_tensor, + {x_local_save, self_conn_grads}); + } + + dE_dr = dE_dr + grads.at(0); // accumulate force + if (std::distance(rit, wrt_tensors.rend()) == 1) + continue; // if last iteration + + of_tensor.clear(); + of_tensor.push_back(wrt_tensor[1]); // x + of_tensor.push_back(wrt_tensor[2]); // self_cont_tmp + + x_local_save = grads.at(1); // for grads_output + x_local = x_local_save.detach(); // grad_outputs & communication + x_dim = x_local.size(1); + + self_conn_grads = grads.at(2); // no communication, for grads_output + + x_ghost = grads.at(3).detach(); // yes communication, not for grads_output + + auto extra_x = torch::zeros({extra_size, x_dim}, FLOAT_TYPE.device(device)); + x_comm = torch::cat({x_local, x_ghost, extra_x}, 0).to(device_comm); + + comm_brick->reverse_comm(this); // completes x_local + + // now x_local is 
complete (dE_dx), become next grads_output(with + // self_conn_grads) + x_local = torch::split_with_sizes( + x_comm, {nlocal, ghost_node_num, extra_size}, 0)[0]; + } + + // postprocessing + if (print_info) { + size_t free, tot; + cudaMemGetInfo(&free, &tot); + std::cout << world_rank << " MEM use after backward(MB)" << std::endl; + double Mfree = static_cast(free) / (1024 * 1024); + double Mtot = static_cast(tot) / (1024 * 1024); + std::cout << world_rank << " Total: " << Mtot << std::endl; + std::cout << world_rank << " Free: " << Mfree << std::endl; + std::cout << world_rank << " Used: " << Mtot - Mfree << std::endl; + double Mused = Mtot - Mfree; + std::cout << world_rank << " Used/Nedges: " << Mused / nedges << std::endl; + std::cout << world_rank << " Used/Nlocal: " << Mused / nlocal << std::endl; + std::cout << world_rank << " Used/GraphSize: " << Mused / graph_size << "\n" + << std::endl; + } + eng_vdwl += energy_tensor.item(); // accumulate energy + + dE_dr = dE_dr.to(torch::kCPU); + torch::Tensor force_tensor = torch::zeros({graph_indexer, 3}); + + auto _edge_idx_src_tensor = + edge_idx_src_tensor.repeat_interleave(3).view({nedges, 3}); + auto _edge_idx_dst_tensor = + edge_idx_dst_tensor.repeat_interleave(3).view({nedges, 3}); + + force_tensor.scatter_reduce_(0, _edge_idx_src_tensor, dE_dr, "sum"); + force_tensor.scatter_reduce_(0, _edge_idx_dst_tensor, torch::neg(dE_dr), + "sum"); + + auto forces = force_tensor.accessor(); + + for (int graph_idx = 0; graph_idx < graph_indexer; graph_idx++) { + int i = graph_index_to_i[graph_idx]; + f[i][0] += forces[graph_idx][0]; + f[i][1] += forces[graph_idx][1]; + f[i][2] += forces[graph_idx][2]; + } + + if (vflag) { + auto diag = inp_edge_vec * dE_dr; + auto s12 = inp_edge_vec.select(1, 0) * dE_dr.select(1, 1); + auto s23 = inp_edge_vec.select(1, 1) * dE_dr.select(1, 2); + auto s31 = inp_edge_vec.select(1, 2) * dE_dr.select(1, 0); + std::vector voigt_list = { + diag, s12.unsqueeze(-1), s23.unsqueeze(-1), 
s31.unsqueeze(-1)}; + auto voigt = torch::cat(voigt_list, 1); + + torch::Tensor per_atom_stress_tensor = torch::zeros({graph_indexer, 6}); + auto _edge_idx_dst6_tensor = + edge_idx_dst_tensor.repeat_interleave(6).view({nedges, 6}); + per_atom_stress_tensor.scatter_reduce_(0, _edge_idx_dst6_tensor, voigt, + "sum"); + auto virial_stress_tensor = + torch::neg(torch::sum(per_atom_stress_tensor, 0)); + auto virial_stress = virial_stress_tensor.accessor(); + + virial[0] += virial_stress[0]; + virial[1] += virial_stress[1]; + virial[2] += virial_stress[2]; + virial[3] += virial_stress[3]; + virial[4] += virial_stress[5]; + virial[5] += virial_stress[4]; + } + + if (eflag_atom) { + torch::Tensor atomic_energy_tensor = + output.at("atomic_energy").toTensor().cpu().view({nlocal}); + auto atomic_energy = atomic_energy_tensor.accessor(); + for (int graph_idx = 0; graph_idx < nlocal; graph_idx++) { + int i = graph_index_to_i[graph_idx]; + eatom[i] += atomic_energy[graph_idx]; + } + } + + // clean up comm preprocess variables + comm_preprocess_done = false; + for (int i = 0; i < 6; i++) { + // array of vector + comm_index_pack_forward[i].clear(); + comm_index_unpack_forward[i].clear(); + comm_index_unpack_reverse[i].clear(); + } + + extra_graph_idx_map.clear(); +} + +// allocate arrays (called from coeff) +void PairE3GNNParallel::allocate() { + allocated = 1; + int n = atom->ntypes; + + memory->create(setflag, n + 1, n + 1, "pair:setflag"); + memory->create(cutsq, n + 1, n + 1, "pair:cutsq"); + memory->create(map, n + 1, "pair:map"); +} + +// global settings for pair_style +void PairE3GNNParallel::settings(int narg, char **arg) { + if (narg != 0) { + error->all(FLERR, "Illegal pair_style command"); + } +} + +void PairE3GNNParallel::coeff(int narg, char **arg) { + if (allocated) { + error->all(FLERR, "pair_e3gnn coeff called twice"); + } + allocate(); + + if (strcmp(arg[0], "*") != 0 || strcmp(arg[1], "*") != 0) { + error->all(FLERR, + "e3gnn: first and second input of pair_coeff 
should be '*'"); + } + // expected input : pair_coeff * * pot.pth type_name1 type_name2 ... + + std::unordered_map meta_dict = { + {"chemical_symbols_to_index", ""}, + {"cutoff", ""}, + {"num_species", ""}, + {"model_type", ""}, + {"version", ""}, + {"dtype", ""}, + {"time", ""}, + {"flashTP", "version mismatch"}, + {"oeq", "version mismatch"}, + {"comm_size", ""}}; + + // model loading from input + int n_model = std::stoi(arg[2]); + int chem_arg_i = 4; + std::vector model_fnames; + if (std::filesystem::exists(arg[3])) { + if (std::filesystem::is_directory(arg[3])) { + auto headf = std::string(arg[3]); + for (int i = 0; i < n_model; i++) { + auto stri = std::to_string(i); + model_fnames.push_back(headf + "/deployed_parallel_" + stri + ".pt"); + } + } else if (std::filesystem::is_regular_file(arg[3])) { + for (int i = 3; i < n_model + 3; i++) { + model_fnames.push_back(std::string(arg[i])); + } + chem_arg_i = n_model + 3; + } else { + error->all(FLERR, "No such file or directory:" + std::string(arg[3])); + } + } + + for (const auto &modelf : model_fnames) { + if (!std::filesystem::is_regular_file(modelf)) { + error->all(FLERR, "Expected this is a regular file:" + modelf); + } + model_list.push_back(torch::jit::load(modelf, device, meta_dict)); + } + + torch::jit::setGraphExecutorOptimize(false); + torch::jit::FusionStrategy strategy; + // strategy = {{torch::jit::FusionBehavior::DYNAMIC, 3}}; + strategy = {{torch::jit::FusionBehavior::STATIC, 0}}; + torch::jit::setFusionStrategy(strategy); + + cutoff = std::stod(meta_dict["cutoff"]); + + // maximum possible size of per atom x before last convolution + int comm_size = std::stod(meta_dict["comm_size"]); + + // to initialize buffer size for communication + comm_forward = comm_size; + comm_reverse = comm_size; + + cutoff_square = cutoff * cutoff; + + // to make torch::autograd::grad() works + if (meta_dict["oeq"] == "yes") { + pair_e3gnn_oeq_register_autograd(); + } + + if 
(meta_dict["model_type"].compare("E3_equivariant_model") != 0) { + error->all(FLERR, "given model type is not E3_equivariant_model"); + } + + std::string chem_str = meta_dict["chemical_symbols_to_index"]; + int ntypes = atom->ntypes; + + auto delim = " "; + char *tok = std::strtok(const_cast(chem_str.c_str()), delim); + std::vector chem_vec; + while (tok != nullptr) { + chem_vec.push_back(std::string(tok)); + tok = std::strtok(nullptr, delim); + } + + // what if unknown chemical specie is in arg? should I abort? is there any use + // case for that? + bool found_flag = false; + int n_chem = narg - chem_arg_i; + for (int i = 0; i < n_chem; i++) { + found_flag = false; + for (int j = 0; j < chem_vec.size(); j++) { + if (chem_vec[j].compare(arg[i + chem_arg_i]) == 0) { + map[i + 1] = j; // store from 1, (not 0) + found_flag = true; + if (lmp->logfile) { + fprintf(lmp->logfile, "Chemical specie '%s' is assigned to type %d\n", + arg[i + chem_arg_i], i + 1); + break; + } + } + } + if (!found_flag) { + error->all(FLERR, "Unknown chemical specie is given or the number of " + "potential files is not consistent"); + } + } + + for (int i = 1; i <= ntypes; i++) { + for (int j = 1; j <= ntypes; j++) { + if ((map[i] >= 0) && (map[j] >= 0)) { + setflag[i][j] = 1; + cutsq[i][j] = cutoff * cutoff; + } + } + } + + if (lmp->logfile) { + fprintf(lmp->logfile, "from sevenn version '%s' ", + meta_dict["version"].c_str()); + fprintf(lmp->logfile, "%s precision model, deployed: %s\n", + meta_dict["dtype"].c_str(), meta_dict["time"].c_str()); + fprintf(lmp->logfile, "FlashTP: %s\n", + meta_dict["flashTP"].c_str()); + fprintf(lmp->logfile, "OEQ: %s\n", + meta_dict["oeq"].c_str()); + } +} + +// init specific to this pair +void PairE3GNNParallel::init_style() { + // full neighbor list & newton on + if (force->newton_pair == 0) { + error->all(FLERR, "Pair style e3gnn/parallel requires newton pair on"); + } + neighbor->add_request(this, NeighConst::REQ_FULL); +} + +double 
PairE3GNNParallel::init_one(int i, int j) { return cutoff; } + +void PairE3GNNParallel::notify_proc_ids(const int *sendproc, const int *recvproc) { + for (int iswap = 0; iswap < 6; iswap++) { + this->sendproc[iswap] = sendproc[iswap]; + this->recvproc[iswap]= recvproc[iswap]; + } +} + +void PairE3GNNParallel::comm_preprocess() { + assert(!comm_preprocess_done); + CommBrick *comm_brick = dynamic_cast(comm); + + // fake lammps communication call to preprocess index + // gives complete comm_index_pack, unpack_forward, and extra_graph_idx_map + comm_brick->forward_comm(this); + + std::map> already_met_map; + for (int comm_phase = 0; comm_phase < 6; comm_phase++) { + const int n = comm_index_pack_forward[comm_phase].size(); + int sproc = this->sendproc[comm_phase]; + if (already_met_map.count(sproc) == 0) { + already_met_map.insert({sproc, std::set()}); + } + + // for unpack_reverse, Ignore duplicated index by 'already_met' + std::vector &idx_map_forward = comm_index_pack_forward[comm_phase]; + std::vector &idx_map_reverse = comm_index_unpack_reverse[comm_phase]; + std::set& already_met = already_met_map[sproc]; + // the last index of x_comm is used to trash unnecessary values + const int trash_index = + graph_size + static_cast(extra_graph_idx_map.size()); //+ 1; + for (int i = 0; i < n; i++) { + const int idx = idx_map_forward[i]; + if (idx < graph_size) { + if (already_met.count(idx) == 1) { + idx_map_reverse.push_back(trash_index); + } else { + idx_map_reverse.push_back(idx); + already_met.insert(idx); + } + } else { + idx_map_reverse.push_back(idx); + } + } + + if (use_cuda_mpi) { + comm_index_pack_forward_tensor[comm_phase] = torch::from_blob(idx_map_forward.data(), idx_map_forward.size(), INTEGER_TYPE).to(device); + + auto upmap = comm_index_unpack_forward[comm_phase]; + comm_index_unpack_forward_tensor[comm_phase] = torch::from_blob(upmap.data(), upmap.size(), INTEGER_TYPE).to(device); + comm_index_unpack_reverse_tensor[comm_phase] = 
torch::from_blob(idx_map_reverse.data(), idx_map_reverse.size(), INTEGER_TYPE).to(device); + } + } + comm_preprocess_done = true; +} + +// called from comm_brick if comm_preprocess_done is false +void PairE3GNNParallel::pack_forward_init(int n, int *list_send, + int comm_phase) { + std::vector &idx_map = comm_index_pack_forward[comm_phase]; + + idx_map.reserve(n); + + int i, j; + int nlocal = list->inum; + tagint *tag = atom->tag; + + for (i = 0; i < n; i++) { + int list_i = list_send[i]; + int graph_idx = tag_to_graph_idx_ptr[tag[list_i]]; + + if (graph_idx != -1) { + // known atom (local atom + ghost atom inside cutoff) + idx_map.push_back(graph_idx); + } else { + // unknown atom, these are not used in computation in this process + // instead, this process is used to hand over these atoms to other proecss + // hold them in continuous manner for flexible tensor operations later + if (extra_graph_idx_map.find(list_i) != extra_graph_idx_map.end()) { + idx_map.push_back(extra_graph_idx_map[list_i]); + } else { + // unknown atom at pack forward, ghost atom outside cutoff? 
+ extra_graph_idx_map[i] = graph_size + extra_graph_idx_map.size(); + idx_map.push_back(extra_graph_idx_map[i]); // same as list_i in pack + } + } + } +} + +// called from comm_brick if comm_preprocess_done is false +void PairE3GNNParallel::unpack_forward_init(int n, int first, int comm_phase) { + std::vector &idx_map = comm_index_unpack_forward[comm_phase]; + + idx_map.reserve(n); + + int i, j, last; + last = first + n; + int nlocal = list->inum; + tagint *tag = atom->tag; + + for (i = first; i < last; i++) { + int graph_idx = tag_to_graph_idx_ptr[tag[i]]; + if (graph_idx != -1) { + idx_map.push_back(graph_idx); + } else { + extra_graph_idx_map[i] = graph_size + extra_graph_idx_map.size(); + idx_map.push_back(extra_graph_idx_map[i]); // same as list_i in pack + } + } +} + +int PairE3GNNParallel::pack_forward_comm_gnn(float *buf, int comm_phase) { + std::vector &idx_map = comm_index_pack_forward[comm_phase]; + const int n = static_cast(idx_map.size()); + if (use_cuda_mpi && n != 0) { + torch::Tensor &idx_map_tensor = comm_index_pack_forward_tensor[comm_phase]; + auto selected = x_comm.index_select(0, idx_map_tensor); // its size is x_dim * n + cudaError_t cuda_err = + cudaMemcpy(buf, selected.data_ptr(), (x_dim * n) * sizeof(float), + cudaMemcpyDeviceToDevice); + } else { + int i, j, m; + m = 0; + for (i = 0; i < n; i++) { + const int idx = static_cast(idx_map.at(i)); + float *from = x_comm[idx].data_ptr(); + for (j = 0; j < x_dim; j++) { + buf[m++] = from[j]; + } + } + } + if (print_info) { + std::cout << world_rank << " comm_phase: " << comm_phase << std::endl; + std::cout << world_rank << " pack_forward x_dim: " << x_dim << std::endl; + std::cout << world_rank << " pack_forward n: " << n << std::endl; + std::cout << world_rank << " pack_forward x_dim*n: " << x_dim * n + << std::endl; + double Msend = static_cast(x_dim * n * 4) / (1024 * 1024); + std::cout << world_rank << " send size(MB): " << Msend << "\n" << std::endl; + } + return x_dim * n; +} + +void 
PairE3GNNParallel::unpack_forward_comm_gnn(float *buf, int comm_phase) { + std::vector &idx_map = comm_index_unpack_forward[comm_phase]; + const int n = static_cast(idx_map.size()); + + if (use_cuda_mpi && n != 0) { + torch::Tensor &idx_map_tensor = comm_index_unpack_forward_tensor[comm_phase]; + auto buf_tensor = + torch::from_blob(buf, {n, x_dim}, FLOAT_TYPE.device(device)); + x_comm.scatter_(0, idx_map_tensor.repeat_interleave(x_dim).view({n, x_dim}), + buf_tensor); + } else { + int i, j, m; + m = 0; + for (i = 0; i < n; i++) { + const int idx = static_cast(idx_map.at(i)); + float *to = x_comm[idx].data_ptr(); + for (j = 0; j < x_dim; j++) { + to[j] = buf[m++]; + } + } + } +} + +int PairE3GNNParallel::pack_reverse_comm_gnn(float *buf, int comm_phase) { + std::vector &idx_map = comm_index_unpack_forward[comm_phase]; + const int n = static_cast(idx_map.size()); + + if (use_cuda_mpi && n != 0) { + torch::Tensor &idx_map_tensor = comm_index_unpack_forward_tensor[comm_phase]; + auto selected = x_comm.index_select(0, idx_map_tensor); + cudaError_t cuda_err = cudaMemcpy(buf, selected.data_ptr(), (x_dim * n) * sizeof(float), cudaMemcpyDeviceToDevice); + } else { + int i, j, m; + m = 0; + for (i = 0; i < n; i++) { + const int idx = static_cast(idx_map.at(i)); + float *from = x_comm[idx].data_ptr(); + for (j = 0; j < x_dim; j++) { + buf[m++] = from[j]; + } + } + } + if (print_info) { + std::cout << world_rank << " comm_phase: " << comm_phase << std::endl; + std::cout << world_rank << " pack_reverse x_dim: " << x_dim << std::endl; + std::cout << world_rank << " pack_reverse n: " << n << std::endl; + std::cout << world_rank << " pack_reverse x_dim*n: " << x_dim * n + << std::endl; + double Msend = static_cast(x_dim * n * 4) / (1024 * 1024); + } + return x_dim * n; +} + +void PairE3GNNParallel::unpack_reverse_comm_gnn(float *buf, int comm_phase) { + std::vector &idx_map = comm_index_unpack_reverse[comm_phase]; + const int n = static_cast(idx_map.size()); + + if (use_cuda_mpi 
&& n != 0) { + torch::Tensor &idx_map_tensor = comm_index_unpack_reverse_tensor[comm_phase]; + auto buf_tensor = + torch::from_blob(buf, {n, x_dim}, FLOAT_TYPE.device(device)); + x_comm.scatter_(0, idx_map_tensor.repeat_interleave(x_dim).view({n, x_dim}), + buf_tensor, "add"); + } else { + int i, j, m; + m = 0; + for (i = 0; i < n; i++) { + const int idx = static_cast(idx_map.at(i)); + if (idx == -1) { + m += x_dim; + continue; + } + float *to = x_comm[idx].data_ptr(); + for (j = 0; j < x_dim; j++) { + to[j] += buf[m++]; + } + } + } +} diff --git a/sevenn/pair_e3gnn/patch_lammps.sh b/sevenn/pair_e3gnn/patch_lammps.sh index ec46c1ba..e1b7eec9 100755 --- a/sevenn/pair_e3gnn/patch_lammps.sh +++ b/sevenn/pair_e3gnn/patch_lammps.sh @@ -5,6 +5,7 @@ cxx_standard=$2 # 14, 17 d3_support=$3 # 1, 0 flashTP_so="${4:-NONE}" oeq_so="${5:-NONE}" +atomic_stress="${6:-0}" # 1, 0 SCRIPT_DIR=$(dirname "${BASH_SOURCE[0]}") enable_flashTP=0 enable_oeq=0 @@ -14,8 +15,8 @@ enable_oeq=0 ########################################### # Check the number of arguments -if [ "$#" -lt 3 ] || [ "$#" -gt 5 ]; then - echo "Usage: sh patch_lammps.sh {lammps_root} {cxx_standard} {d3_support} {flashTP_so} {oeq_so}" +if [ "$#" -lt 3 ] || [ "$#" -gt 6 ]; then + echo "Usage: sh patch_lammps.sh {lammps_root} {cxx_standard} {d3_support} {flashTP_so} {oeq_so} {atomic_stress}" echo " {lammps_root}: Root directory of LAMMPS source" echo " {cxx_standard}: C++ standard (14, 17)" echo " {d3_support}: Support for pair_d3 (1, 0)" @@ -119,7 +120,13 @@ cp $lammps_root/cmake/CMakeLists.txt $backup_dir/CMakeLists.txt ########################################### # 1. 
Copy pair_e3gnn files to LAMMPS source +if [ "$atomic_stress" -eq 1 ]; then +cp $SCRIPT_DIR/pair_e3gnn_atomic_stress.cpp $lammps_root/src/pair_e3gnn.cpp +cp $SCRIPT_DIR/pair_e3gnn_parallel_atomic_stress.cpp $lammps_root/src/pair_e3gnn_parallel.cpp +cp $SCRIPT_DIR/comm_brick.cpp $lammps_root/src/ +else cp $SCRIPT_DIR/{pair_e3gnn,pair_e3gnn_parallel,comm_brick}.cpp $lammps_root/src/ +fi cp $SCRIPT_DIR/{pair_e3gnn,pair_e3gnn_parallel,comm_brick}.h $lammps_root/src/ # Always copy the oEq autograd bridge (pair_e3gnn.cpp has an extern reference to it) cp $SCRIPT_DIR/pair_e3gnn_oeq_autograd.cpp $lammps_root/src/ # TODO: set this as oeq-specific @@ -199,6 +206,9 @@ fi echo "Changes made:" echo " - Original LAMMPS files (src/comm_brick.*, cmake/CMakeList.txt) are in {lammps_root}/_backups" echo " - Copied contents of pair_e3gnn to $lammps_root/src/" +if [ "$atomic_stress" -eq 1 ]; then + echo " - Atomic stress patch mode enabled: using pair_e3gnn*_atomic_stress.cpp" +fi echo " - Patched CMakeLists.txt: include LibTorch, CXX_STANDARD $cxx_standard" echo if [ "$enable_flashTP" -eq 1 ]; then diff --git a/sevenn/scripts/deploy.py b/sevenn/scripts/deploy.py index c81a9ad1..573ccc78 100644 --- a/sevenn/scripts/deploy.py +++ b/sevenn/scripts/deploy.py @@ -19,12 +19,9 @@ def deploy( modal: Optional[str] = None, use_flash: bool = False, use_oeq: bool = False, + atomic_virial: bool = False, ) -> None: - from sevenn.nn.edge_embedding import EdgePreprocess - from sevenn.nn.force_output import ForceStressOutput - cp = load_checkpoint(checkpoint) - model, config = ( cp.build_model( enable_cueq=False, @@ -35,11 +32,8 @@ def deploy( cp.config, ) - model.prepand_module('edge_preprocess', EdgePreprocess(True)) - grad_module = ForceStressOutput() - model.replace_module('force_output', grad_module) - new_grad_key = grad_module.get_grad_key() - model.key_grad = new_grad_key + if 'force_output' in model._modules: + model.delete_module_by_key('force_output') if hasattr(model, 'eval_type_map'): 
setattr(model, 'eval_type_map', False) @@ -74,6 +68,7 @@ def deploy( md_configs.update({'version': __version__}) md_configs.update({'dtype': config.pop(KEY.DTYPE, 'single')}) md_configs.update({'time': datetime.now().strftime('%Y-%m-%d')}) + md_configs.update({'atomic_virial': 'yes' if atomic_virial else 'no'}) if fname.endswith('.pt') is False: fname += '.pt' @@ -175,3 +170,79 @@ def deploy_parallel( model = torch.jit.freeze(model) torch.jit.save(model, fname_full, _extra_files=md_configs) + + +def deploy_ts( + checkpoint: Union[pathlib.Path, str], + fname='deployed_model.pt', + modal: Optional[str] = None, + use_flash: bool = False, + use_oeq: bool = False, + atomic_virial: bool = False, +) -> None: + ''' + only for SevenNetCalculator with torchscript input (not for e3gnn) + ''' + from sevenn.nn.edge_embedding import EdgePreprocess + from sevenn.nn.force_output import ( + ForceStressOutput, + ForceStressOutputFromEdge, + ) + + cp = load_checkpoint(checkpoint) + + model, config = ( + cp.build_model( + enable_cueq=False, + enable_flash=use_flash, + enable_oeq=use_oeq, + _flash_lammps=use_flash, + ), + cp.config, + ) + + model.prepand_module('edge_preprocess', EdgePreprocess(True)) + grad_module = ForceStressOutputFromEdge(use_atomic_virial=True) + model.replace_module('force_output', grad_module) + new_grad_key = grad_module.get_grad_key() + model.key_grad = new_grad_key + + if hasattr(model, 'eval_type_map'): + setattr(model, 'eval_type_map', False) + + if modal: + model.prepare_modal_deploy(modal) + elif model.modal_map is not None and len(model.modal_map) >= 1: + raise ValueError( + f'Modal is not given. 
It has: {list(model.modal_map.keys())}' + ) + + model.set_is_batch_data(False) + model.eval() + + model = e3nn.util.jit.script(model) + model = torch.jit.freeze(model) + + # make some configs needed for md + md_configs = {} + type_map = config[KEY.TYPE_MAP] + chem_list = '' + for Z in type_map.keys(): + chem_list += chemical_symbols[Z] + ' ' + chem_list = chem_list.strip() + md_configs.update({'chemical_symbols_to_index': chem_list}) + md_configs.update({'cutoff': str(config[KEY.CUTOFF])}) + md_configs.update({'num_species': str(config[KEY.NUM_SPECIES])}) + md_configs.update({'flashTP': 'yes' if use_flash else 'no'}) + md_configs.update({'oeq': 'yes' if use_oeq else 'no'}) + md_configs.update( + {'model_type': config.pop(KEY.MODEL_TYPE, 'E3_equivariant_model')} + ) + md_configs.update({'version': __version__}) + md_configs.update({'dtype': config.pop(KEY.DTYPE, 'single')}) + md_configs.update({'time': datetime.now().strftime('%Y-%m-%d')}) + md_configs.update({'atomic_virial': 'yes' if atomic_virial else 'no'}) + + if fname.endswith('.pt') is False: + fname += '.pt' + torch.jit.save(model, fname, _extra_files=md_configs) diff --git a/sevenn/scripts/inference.py b/sevenn/scripts/inference.py index c3e83428..70a8401e 100644 --- a/sevenn/scripts/inference.py +++ b/sevenn/scripts/inference.py @@ -18,6 +18,7 @@ def write_inference_csv(output_list, out): output = output.fit_dimension() output[KEY.STRESS] = output[KEY.STRESS] * 1602.1766208 output[KEY.PRED_STRESS] = output[KEY.PRED_STRESS] * 1602.1766208 + # atomic virial: keep model units (energy-like) output_list[i] = output.to_numpy_dict() per_graph_keys = [ @@ -35,6 +36,7 @@ KEY.POS, KEY.FORCE, KEY.PRED_FORCE, + KEY.PRED_ATOMIC_VIRIAL, ] def unfold_dct_val(dct, keys, suffix_list=None): @@ -53,15 +55,36 @@ return res def per_atom_dct_list(dct, keys): - sfx_list = ['x', 'y', 'z'] res = [] - natoms = dct[KEY.NUM_ATOMS] - extracted = {k: dct[k] 
for k in keys} + natoms = int(dct[KEY.NUM_ATOMS]) for i in range(natoms): - raw = {} - raw.update({k: v[i] for k, v in extracted.items()}) - per_atom_dct = unfold_dct_val(raw, keys, suffix_list=sfx_list) - res.append(per_atom_dct) + entry = {} + for k in keys: + if k not in dct: + continue + v = dct[k] + if isinstance(v, np.ndarray): + if v.ndim == 0: + entry[k] = v.item() + elif v.ndim == 1: + entry[k] = v[i] + elif v.ndim == 2: + d = v.shape[1] + if k in (KEY.POS, KEY.FORCE, KEY.PRED_FORCE): + sfx = ['x', 'y', 'z'] + elif k == KEY.PRED_ATOMIC_VIRIAL: + sfx = ['xx', 'yy', 'zz', 'xy', 'yz', 'zx'] + else: + sfx = [str(j) for j in range(d)] + for j in range(d): + entry[f'{k}_{sfx[j]}'] = v[i, j] + else: + flat = v[i].ravel() + for j, val in enumerate(flat): + entry[f'{k}_{j}'] = val + else: + entry[k] = v + res.append(entry) return res try: diff --git a/sevenn/util.py b/sevenn/util.py index d7ce7888..20cf0464 100644 --- a/sevenn/util.py +++ b/sevenn/util.py @@ -43,6 +43,12 @@ def to_atom_graph_list(atom_graph_batch) -> List[_const.AtomGraphDataType]: if is_stress: inferred_stress_list = torch.unbind(atom_graph_batch[KEY.PRED_STRESS]) + inferred_virial_list = None + if KEY.PRED_ATOMIC_VIRIAL in atom_graph_batch: + inferred_virial_list = torch.split( + atom_graph_batch[KEY.PRED_ATOMIC_VIRIAL], indices + ) + for i, data in enumerate(data_list): data[KEY.ATOMIC_ENERGY] = atomic_energy_list[i] data[KEY.PRED_TOTAL_ENERGY] = inferred_total_energy_list[i] @@ -50,6 +56,8 @@ def to_atom_graph_list(atom_graph_batch) -> List[_const.AtomGraphDataType]: # To fit with KEY.STRESS (ref) format if is_stress and inferred_stress_list is not None: data[KEY.PRED_STRESS] = torch.unsqueeze(inferred_stress_list[i], 0) + if inferred_virial_list is not None: + data[KEY.PRED_ATOMIC_VIRIAL] = inferred_virial_list[i] return data_list diff --git a/tests/lammps_tests/scripts/stress_skel.lmp b/tests/lammps_tests/scripts/stress_skel.lmp new file mode 100644 index 00000000..10b69e8c --- /dev/null +++ 
b/tests/lammps_tests/scripts/stress_skel.lmp @@ -0,0 +1,21 @@ + units metal + boundary __BOUNDARY__ + read_data __LMP_STCT__ + + mass * 1.0 # do not matter since we don't run MD + + pair_style __PAIR_STYLE__ + pair_coeff * * __POTENTIALS__ __ELEMENT__ + + timestep 0.002 + + compute pa all pe/atom + compute astress all stress/atom NULL pair + + thermo 1 + fix 1 all nve + thermo_style custom step tpcpu pe ke vol pxx pyy pzz pxy pxz pyz press temp + dump mydump all custom 1 __FORCE_DUMP_PATH__ id type element c_pa c_astress[1] c_astress[2] c_astress[3] c_astress[4] c_astress[5] c_astress[6] x y z fx fy fz + dump_modify mydump sort id element __ELEMENT__ + + run 0 diff --git a/tests/lammps_tests/test_lammps.py b/tests/lammps_tests/test_lammps.py index 8847feea..502fd3e1 100644 --- a/tests/lammps_tests/test_lammps.py +++ b/tests/lammps_tests/test_lammps.py @@ -18,7 +18,7 @@ from sevenn.nn.cue_helper import is_cue_available from sevenn.nn.flash_helper import is_flash_available from sevenn.nn.oeq_helper import is_oeq_available -from sevenn.scripts.deploy import deploy, deploy_parallel +from sevenn.scripts.deploy import deploy, deploy_parallel, deploy_ts from sevenn.util import chemical_species_preprocess, pretrained_name_to_path logger = logging.getLogger('test_lammps') @@ -28,6 +28,9 @@ lmp_script_path = str( (pathlib.Path(__file__).parent / 'scripts' / 'skel.lmp').resolve() ) +lmp_stress_script_path = str( + (pathlib.Path(__file__).parent / 'scripts' / 'stress_skel.lmp').resolve() +) data_root = (pathlib.Path(__file__).parent.parent / 'data').resolve() cp_0_path = str(data_root / 'checkpoints' / 'cp_0.pth') # knows Hf, O @@ -106,6 +109,22 @@ def ref_modal_calculator(): return SevenNetCalculator(cp_mf_path, modal='PBE') +@pytest.fixture(scope='module') +def ref_stress_calculator(tmp_path_factory): + tmp = tmp_path_factory.mktemp('serial_potential_atomic_virial') + pot_path = str(tmp / 'deployed_atomic_virial.pt') + deploy_ts(cp_0_path, pot_path, atomic_virial=True) + 
return SevenNetCalculator(pot_path, file_type='torchscript') + + +@pytest.fixture(scope='module') +def ref_7net0_stress_calculator(tmp_path_factory): + tmp = tmp_path_factory.mktemp('serial_potential_atomic_virial') + pot_path = str(tmp / 'deployed_atomic_virial.pt') + deploy_ts(cp_7net0_path, pot_path, atomic_virial=True) + return SevenNetCalculator(pot_path, file_type='torchscript') + + def get_model_config(): config = { 'cutoff': cutoff, @@ -176,7 +195,7 @@ def get_system(system_name, **kwargs): raise ValueError() -def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6): +def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6, check_atomic_stress=False): def acl(a, b, rtol=rtol, atol=atol): return np.allclose(a, b, rtol=rtol, atol=atol) @@ -190,6 +209,11 @@ def acl(a, b, rtol=rtol, atol=atol): rtol * 10, atol * 10, ) + if check_atomic_stress: + ref_atomic_virial = np.asarray(atoms1.calc.results['stresses']) + lmp_atomic_stress = np.asarray(atoms2.calc.results['atomic_stress']) + lmp_atomic_virial = -lmp_atomic_stress[:, [0, 1, 2, 3, 5, 4]] + assert acl(ref_atomic_virial, lmp_atomic_virial, rtol * 10, atol * 10) # assert acl(atoms1.get_potential_energies(), atoms2.get_potential_energies()) @@ -219,7 +243,7 @@ def _lammps_results_to_atoms(lammps_log, force_dump): 'lmp_dump': force_dump, } # atomic energy read - latoms.calc.results['energies'] = latoms.arrays['c_pa'][:, 0] + latoms.calc.results['energies'] = np.ravel(latoms.arrays['c_pa']) stress = np.array( [ [lmp_log['Pxx'], lmp_log['Pxy'], lmp_log['Pxz']], @@ -230,10 +254,41 @@ def _lammps_results_to_atoms(lammps_log, force_dump): stress = -1 * stress / 1602.1766208 / 1000 # convert bars to eV/A^3 latoms.calc.results['stress'] = stress + if 'c_astress[1]' in latoms.arrays: + atomic_stress = np.column_stack( + [ + np.asarray(latoms.arrays['c_astress[1]']), + np.asarray(latoms.arrays['c_astress[2]']), + np.asarray(latoms.arrays['c_astress[3]']), + np.asarray(latoms.arrays['c_astress[4]']), + 
np.asarray(latoms.arrays['c_astress[5]']), + np.asarray(latoms.arrays['c_astress[6]']), + ] + ) + latoms.calc.results['atomic_stress'] = atomic_stress / 1602.1766208 / 1000 + return latoms -def _run_lammps(atoms, pair_style, potential, wd, command, test_name): +def _run_lammps(atoms, pair_style, potential, wd, command, test_name, script_path): + def _rotate_stress(atomic_stress, rot_mat): + out = np.empty_like(atomic_stress) + for i, s in enumerate(atomic_stress): + sigma = np.array([ + [s[0], s[3], s[4]], + [s[3], s[1], s[5]], + [s[4], s[5], s[2]] + ]) + sigma = rot_mat @ sigma @ rot_mat.T + out[i] = [ + sigma[0, 0], + sigma[1, 1], + sigma[2, 2], + sigma[0, 1], + sigma[0, 2], + sigma[1, 2] + ] + return out wd = wd.resolve() pbc = atoms.get_pbc() pbc_str = ' '.join(['p' if x else 'f' for x in pbc]) @@ -248,7 +303,7 @@ def _run_lammps(atoms, pair_style, potential, wd, command, test_name): lmp_stct, atoms, prismobj=prism, specorder=chem ) - with open(lmp_script_path, 'r') as f: + with open(script_path, 'r') as f: cont = f.read() lammps_log = str(wd / 'log.lammps') @@ -276,6 +331,10 @@ def _run_lammps(atoms, pair_style, potential, wd, command, test_name): rot_mat = prism.rot_mat results = copy.deepcopy(lmp_atoms.calc.results) + + # SinglePointCalculator does not know atomic_stress + at_stress = results.pop('atomic_stress', None) + r_force = np.dot(results['forces'], rot_mat.T) results['forces'] = r_force if 'stress' in results: @@ -287,19 +346,33 @@ def _run_lammps(atoms, pair_style, potential, wd, command, test_name): lmp_atoms.set_cell(r_cell, scale_atoms=True) lmp_atoms = SinglePointCalculator(lmp_atoms, **results).get_atoms() + if at_stress is not None: + lmp_atoms.calc.results['atomic_stress'] = _rotate_stress(at_stress, rot_mat) + return lmp_atoms def serial_lammps_run(atoms, potential, wd, test_name, lammps_cmd): command = lammps_cmd - return _run_lammps(atoms, 'e3gnn', potential, wd, command, test_name) + return _run_lammps( + atoms, 'e3gnn', potential, wd, 
command, test_name, lmp_script_path + ) def parallel_lammps_run( atoms, potential, wd, test_name, ncores, lammps_cmd, mpirun_cmd ): command = f'{mpirun_cmd} -np {ncores} {lammps_cmd}' - return _run_lammps(atoms, 'e3gnn/parallel', potential, wd, command, test_name) + return _run_lammps( + atoms, 'e3gnn/parallel', potential, wd, command, test_name, lmp_script_path + ) + + +def serial_stress_lammps_run(atoms, potential, wd, test_name, lammps_cmd): + command = lammps_cmd + return _run_lammps( + atoms, 'e3gnn', potential, wd, command, test_name, lmp_stress_script_path + ) def subprocess_routine(cmd, name): @@ -370,6 +443,65 @@ def test_serial_flash( assert_atoms(atoms, atoms_lammps, atol=1e-5) +@pytest.mark.parametrize( + 'system', + ['bulk', 'surface'], +) +def test_serial_stress( + system, serial_potential_path, ref_stress_calculator, lammps_cmd, tmp_path +): + atoms = get_system(system) + atoms_lammps = serial_stress_lammps_run( + atoms=atoms, + potential=serial_potential_path, + wd=tmp_path, + test_name='serial stress lmp test', + lammps_cmd=lammps_cmd, + ) + atoms.calc = ref_stress_calculator + assert_atoms(atoms, atoms_lammps, atol=1e-5, check_atomic_stress=True) + + +@pytest.mark.skipif(not is_oeq_available(), reason='oeq not available') +@pytest.mark.parametrize( + 'system', + ['bulk', 'surface'], +) +def test_serial_stress_oeq( + system, serial_potential_path_oeq, ref_7net0_stress_calculator, lammps_cmd, tmp_path +): + atoms = get_system(system) + atoms_lammps = serial_stress_lammps_run( + atoms=atoms, + potential=serial_potential_path_oeq, + wd=tmp_path, + test_name='serial oeq stress lmp test', + lammps_cmd=lammps_cmd, + ) + atoms.calc = ref_7net0_stress_calculator + assert_atoms(atoms, atoms_lammps, atol=1e-5, check_atomic_stress=True) + + +@pytest.mark.skipif(not is_flash_available(), reason='flash not available') +@pytest.mark.parametrize( + 'system', + ['bulk', 'surface'], +) +def test_serial_stress_flash( + system, serial_potential_path_flash, 
ref_7net0_stress_calculator, lammps_cmd, tmp_path +): + atoms = get_system(system) + atoms_lammps = serial_stress_lammps_run( + atoms=atoms, + potential=serial_potential_path_flash, + wd=tmp_path, + test_name='serial flash stress lmp test', + lammps_cmd=lammps_cmd, + ) + atoms.calc = ref_7net0_stress_calculator + assert_atoms(atoms, atoms_lammps, atol=1e-5, check_atomic_stress=True) + + @pytest.mark.parametrize( 'system,ncores', [ diff --git a/tests/unit_tests/test_atomic_virial.py b/tests/unit_tests/test_atomic_virial.py new file mode 100644 index 00000000..3bc56237 --- /dev/null +++ b/tests/unit_tests/test_atomic_virial.py @@ -0,0 +1,61 @@ +import numpy as np +from ase.build import bulk + +import sevenn._keys as KEY +from sevenn.calculator import SevenNetCalculator +from sevenn.scripts.deploy import deploy_ts +from sevenn.util import pretrained_name_to_path + + +def _get_atoms_pbc(): + atoms = bulk('NaCl', 'rocksalt', a=5.63) + atoms.set_cell([[1.0, 2.815, 2.815], [2.815, 0.0, 2.815], [2.815, 2.815, 0.0]]) + atoms.set_positions([[0.0, 0.0, 0.0], [2.815, 0.0, 0.0]]) + return atoms + + +def test_atomic_virial_is_exposed_in_python_torchscript_path(tmp_path): + model_path = str(tmp_path / '7net_0_atomic_virial.pt') + deploy_ts(pretrained_name_to_path('7net-0_11July2024'), model_path, atomic_virial=True) + + calc = SevenNetCalculator(model_path, file_type='torchscript') + atoms = _get_atoms_pbc() + atoms.calc = calc + _ = atoms.get_potential_energy() + + assert 'stresses' in calc.results + atomic_virial = np.asarray(calc.results['stresses']) + + assert atomic_virial.shape == (len(atoms), 6) + assert np.isfinite(atomic_virial).all() + assert np.any(np.abs(atomic_virial) > 0.0) + # Full 6-component per-atom reference check. + # Sort rows for deterministic comparison when atom-wise ordering changes. 
+ atomic_virial_ref = np.array( + [ + [10.06461800, -0.07430478, -0.07463801, -0.38235345, -0.04390856, -0.47967120], + [13.55999100, 1.41123860, 1.41158500, -1.28466740, -0.17591128, -1.18766780], + ] + ) + atomic_virial_sorted = atomic_virial[np.argsort(atomic_virial[:, 0])] + atomic_virial_ref_sorted = atomic_virial_ref[np.argsort(atomic_virial_ref[:, 0])] + assert np.allclose(atomic_virial_sorted, atomic_virial_ref_sorted, atol=1e-4) + + # Internal model stress ordering is [xx, yy, zz, xy, yz, zx]. + # Calculator exposes ASE stress as -stress_internal with + # component order [xx, yy, zz, yz, zx, xy]. + virial_sum = atomic_virial.sum(axis=0) + virial_sum_ref = np.array([ + 23.62459886, + 1.33693361, + 1.33694608, + -1.66702306, + -0.21981908, + -1.66734152, + ]) + assert np.allclose(virial_sum, virial_sum_ref, atol=1e-4) + + stress_internal_from_virial = virial_sum / atoms.get_volume() + stress_ase_from_virial = -stress_internal_from_virial[[0, 1, 2, 4, 5, 3]] + + assert np.allclose(calc.results['stress'], stress_ase_from_virial, atol=1e-5) diff --git a/tests/unit_tests/test_calculator.py b/tests/unit_tests/test_calculator.py index 9b19308d..e4aacf2e 100644 --- a/tests/unit_tests/test_calculator.py +++ b/tests/unit_tests/test_calculator.py @@ -8,7 +8,7 @@ from sevenn.calculator import D3Calculator, SevenNetCalculator from sevenn.nn.cue_helper import is_cue_available from sevenn.nn.flash_helper import is_flash_available -from sevenn.scripts.deploy import deploy +from sevenn.scripts.deploy import deploy_ts from sevenn.util import model_from_checkpoint, pretrained_name_to_path @@ -108,7 +108,7 @@ def test_sevennet_0_cal_mol(atoms_mol, sevennet_0_cal): def test_sevennet_0_cal_deployed_consistency(tmp_path, atoms_pbc): atoms_pbc.rattle(stdev=0.01, seed=42) fname = str(tmp_path / '7net_0.pt') - deploy(pretrained_name_to_path('7net-0_11July2024'), fname) + deploy_ts(pretrained_name_to_path('7net-0_11July2024'), fname) calc_script = SevenNetCalculator(fname, 
file_type='torchscript') calc_cp = SevenNetCalculator(pretrained_name_to_path('7net-0_11July2024'))