diff --git a/sevenn/_keys.py b/sevenn/_keys.py index 0c9af7b7..4b8d63d7 100644 --- a/sevenn/_keys.py +++ b/sevenn/_keys.py @@ -65,6 +65,7 @@ SCALED_FORCE: Final[str] = 'scaled_force' PRED_STRESS: Final[str] = 'inferred_stress' +PRED_ATOMIC_VIRIAL: Final[str] = 'inferred_atomic_virial' SCALED_STRESS: Final[str] = 'scaled_stress' # very general data property for AtomGraphData diff --git a/sevenn/calculator.py b/sevenn/calculator.py index 2f4a3d59..81ad325c 100644 --- a/sevenn/calculator.py +++ b/sevenn/calculator.py @@ -40,6 +40,7 @@ def __init__( enable_cueq: Optional[bool] = False, enable_flash: Optional[bool] = False, enable_oeq: Optional[bool] = False, + atomic_virial: bool = False, sevennet_config: Optional[Dict] = None, # Not used in logic, just meta info **kwargs, ) -> None: @@ -51,7 +52,7 @@ def __init__( Name of pretrained models (7net-omni, 7net-mf-ompa, 7net-omat, 7net-l3i5, 7net-0) or path to the checkpoint, deployed model or the model itself file_type: str, default='checkpoint' - one of 'checkpoint' | 'torchscript' | 'model_instance' + one of 'checkpoint' | 'model_instance' device: str | torch.device, default='auto' if not given, use CUDA if available modal: str | None, default=None @@ -69,14 +70,18 @@ def __init__( if True, use OpenEquivariance to accelerate inference. sevennet_config: dict | None, default=None Not used, but can be used to carry meta information of this calculator + atomic_virial: bool, default=False + If True, request per-atom virial output (`stresses`) at runtime. """ super().__init__(**kwargs) self.sevennet_config = None + self.atomic_virial_requested = atomic_virial + self.atomic_virial_from_deploy = False if isinstance(model, pathlib.PurePath): model = str(model) - allowed_file_types = ['checkpoint', 'torchscript', 'model_instance'] + allowed_file_types = ['checkpoint', 'model_instance'] file_type = file_type.lower() if file_type not in allowed_file_types: raise ValueError(f'file_type not in {allowed_file_types}') @@ -131,6 +136,7 @@ def __init__( 'version': b'', 'dtype': b'', 'time': b'', + 'atomic_virial': b'', } model_loaded = torch.jit.load( model, _extra_files=extra_dict, map_location=self.device @@ -141,6 +147,9 @@ def __init__( sym_to_num[sym]: i for i, sym in enumerate(chem_symbols.split()) } self.cutoff = float(extra_dict['cutoff'].decode('utf-8')) + self.atomic_virial_from_deploy = ( + extra_dict['atomic_virial'].decode('utf-8') == 'yes' + ) elif isinstance(model, AtomGraphSequential): if model.type_map is None: @@ -161,6 +170,13 @@ def __init__( self.sevennet_config = sevennet_config self.model = model_loaded + if isinstance(self.model, AtomGraphSequential): + force_output = self.model._modules.get('force_output') + if force_output is not None: + self.atomic_virial_from_deploy = ( + self.atomic_virial_from_deploy + or bool(getattr(force_output, 'use_atomic_virial', False)) + ) self.modal = None if isinstance(self.model, AtomGraphSequential): @@ -182,6 +198,7 @@ def __init__( 'energy', 'forces', 'stress', + 'stresses', 'energies', ] @@ -206,8 +223,7 @@ def output_to_results(self, output: Dict[str, torch.Tensor]) -> Dict[str, Any]: .cpu() .numpy()[[0, 1, 2, 4, 5, 3]] # as voigt notation ) - # Store results - return { + results: Dict[str, Any] = { 'free_energy': energy, 'energy': energy, 'energies': atomic_energies, @@ -215,6 +231,12 @@ def output_to_results(self, output: Dict[str, torch.Tensor]) -> Dict[str, Any]: 'stress': stress, 'num_edges': output[KEY.EDGE_IDX].shape[1], } + if KEY.PRED_ATOMIC_VIRIAL in output: + virial = ( + output[KEY.PRED_ATOMIC_VIRIAL].detach().cpu().numpy()[:num_atoms, :] + ) + results['stresses'] = virial + return results def calculate(self, atoms=None, properties=None, system_changes=all_changes): is_ts_type = isinstance(self.model, torch_script_type) @@ -241,8 +263,16 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes): data[KEY.EDGE_VEC].requires_grad_(True) # backward compatibility data = data.to_dict() del data['data_info'] - - self.results = self.output_to_results(self.model(data)) + elif self.atomic_virial_requested: + force_output = self.model._modules.get('force_output') + if ( + force_output is not None + and hasattr(force_output, 'use_atomic_virial') + ): + setattr(force_output, 'use_atomic_virial', True) + + output = self.model(data) + self.results = self.output_to_results(output) class SevenNetD3Calculator(SumCalculator): diff --git a/sevenn/nn/force_output.py b/sevenn/nn/force_output.py index cd1cc816..0ee2073b 100644 --- a/sevenn/nn/force_output.py +++ b/sevenn/nn/force_output.py @@ -149,7 +149,9 @@ def __init__( data_key_energy: str = KEY.PRED_TOTAL_ENERGY, data_key_force: str = KEY.PRED_FORCE, data_key_stress: str = KEY.PRED_STRESS, + data_key_atomic_virial: str = KEY.PRED_ATOMIC_VIRIAL, data_key_cell_volume: str = KEY.CELL_VOLUME, + use_atomic_virial: bool = False, ) -> None: super().__init__() @@ -158,7 +160,9 @@ def __init__( self.key_energy = data_key_energy self.key_force = data_key_force self.key_stress = data_key_stress + self.key_atomic_virial = data_key_atomic_virial self.key_cell_volume = data_key_cell_volume + self.use_atomic_virial = use_atomic_virial self._is_batch_data = True def get_grad_key(self) -> str: @@ -206,6 +210,8 @@ def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: _s = torch.zeros(tot_num, 6, dtype=fij.dtype, device=fij.device) _edge_dst6 = broadcast(edge_idx[1], _virial, 0) _s.scatter_reduce_(0, _edge_dst6, _virial, reduce='sum') + if self.use_atomic_virial: + data[self.key_atomic_virial] = torch.neg(_s) if self._is_batch_data: batch = data[KEY.BATCH] # for deploy, must be defined first diff --git a/sevenn/pair_e3gnn/pair_e3gnn.cpp b/sevenn/pair_e3gnn/pair_e3gnn.cpp index 3e99b80b..9d833c83 100644 --- a/sevenn/pair_e3gnn/pair_e3gnn.cpp +++ b/sevenn/pair_e3gnn/pair_e3gnn.cpp @@ -67,7 +67,7 @@ PairE3GNN::~PairE3GNN() { memory->destroy(setflag); memory->destroy(cutsq); memory->destroy(map); - memory->destroy(elements); + // memory->destroy(elements); } } @@ -77,42 +77,29 @@ void PairE3GNN::compute(int eflag, int vflag) { This compute function is ispired/modified from stress branch of pair-nequip https://github.com/mir-group/pair_nequip */ - if (eflag || vflag) ev_setup(eflag, vflag); else evflag = vflag_fdotr = 0; - if (vflag_atom) { - error->all(FLERR, "atomic stress is not supported\n"); - } - - int nlocal = list->inum; // same as nlocal - int *ilist = list->ilist; - tagint *tag = atom->tag; - std::unordered_map tag_map; if (atom->tag_consecutive() == 0) { - for (int ii = 0; ii < nlocal; ii++) { - const int i = ilist[ii]; - int itag = tag[i]; - tag_map[itag] = ii+1; - // printf("MODIFY setting %i => %i \n",itag, tag_map[itag] ); - } - } else { - //Ordered which mappling required - for (int ii = 0; ii < nlocal; ii++) { - const int itag = ilist[ii]+1; - tag_map[itag] = ii+1; - // printf("normal setting %i => %i \n",itag, tag_map[itag] ); - } + error->all(FLERR, "Pair e3gnn requires consecutive atom IDs"); } double **x = atom->x; double **f = atom->f; int *type = atom->type; - long num_atoms[1] = {nlocal}; + int nlocal = list->inum; // same as nlocal, Why? is it different from atom->nlocal? + int *ilist = list->ilist; + int inum = list->inum; - int tag2i[nlocal]; + // tag ignore PBC + tagint *tag = atom->tag; + + std::unordered_map tag_map; + std::vector graph_index_to_i(nlocal); + + long num_atoms[1] = {nlocal}; int *numneigh = list->numneigh; // j loop cond int **firstneigh = list->firstneigh; // j list @@ -125,78 +112,64 @@ void PairE3GNN::compute(int eflag, int vflag) { } const int nedges_upper_bound = bound; - float cell[3][3]; - cell[0][0] = domain->boxhi[0] - domain->boxlo[0]; - cell[0][1] = 0.0; - cell[0][2] = 0.0; - - cell[1][0] = domain->xy; - cell[1][1] = domain->boxhi[1] - domain->boxlo[1]; - cell[1][2] = 0.0; - - cell[2][0] = domain->xz; - cell[2][1] = domain->yz; - cell[2][2] = domain->boxhi[2] - domain->boxlo[2]; - - torch::Tensor inp_cell = torch::from_blob(cell, {3, 3}, FLOAT_TYPE); - torch::Tensor inp_num_atoms = torch::from_blob(num_atoms, {1}, INTEGER_TYPE); - - torch::Tensor inp_node_type = torch::zeros({nlocal}, INTEGER_TYPE); - torch::Tensor inp_pos = torch::zeros({nlocal, 3}); - - torch::Tensor inp_cell_volume = - torch::dot(inp_cell[0], torch::cross(inp_cell[1], inp_cell[2], 0)); - - float pbc_shift_tmp[nedges_upper_bound][3]; - - auto node_type = inp_node_type.accessor(); - auto pos = inp_pos.accessor(); + std::vector node_type; + float edge_vec[nedges_upper_bound][3]; long edge_idx_src[nedges_upper_bound]; long edge_idx_dst[nedges_upper_bound]; - int nedges = 0; - - for (int ii = 0; ii < nlocal; ii++) { + for (int ii = 0; ii < inum; ii++) { + // populate tag_map of local atoms const int i = ilist[ii]; - int itag = tag_map[tag[i]]; - tag2i[itag - 1] = i; + const int itag = tag[i]; const int itype = type[i]; - node_type[itag - 1] = map[itype]; - pos[itag - 1][0] = x[i][0]; - pos[itag - 1][1] = x[i][1]; - pos[itag - 1][2] = x[i][2]; + tag_map[itag] = ii; + graph_index_to_i[ii] = i; + node_type.push_back(map[itype]); } - for (int ii = 0; ii < nlocal; ii++) { + int nedges = 0; + // loop over neighbors, build graph + for (int ii = 0; ii < inum; ii++) { const int i = ilist[ii]; - int itag = tag_map[tag[i]]; + const int i_graph_idx = ii; const int *jlist = firstneigh[i]; const int jnum = numneigh[i]; for (int jj = 0; jj < jnum; jj++) { - int j = jlist[jj]; // atom over pbc is different atom - int jtag = tag_map[tag[j]]; // atom over pbs is same atom (it starts from 1) + int j = jlist[jj]; + const int jtag = tag[j]; j &= NEIGHMASK; - const int jtype = type[j]; + const auto found = tag_map.find(jtag); + if (found == tag_map.end()) continue; + const int j_graph_idx = found->second; + + // we have to calculate Rij to check cutoff in lammps side const double delij[3] = {x[j][0] - x[i][0], x[j][1] - x[i][1], x[j][2] - x[i][2]}; const double Rij = delij[0] * delij[0] + delij[1] * delij[1] + delij[2] * delij[2]; - if (Rij < cutoff_square) { - edge_idx_src[nedges] = itag - 1; - edge_idx_dst[nedges] = jtag - 1; - - pbc_shift_tmp[nedges][0] = x[j][0] - pos[jtag - 1][0]; - pbc_shift_tmp[nedges][1] = x[j][1] - pos[jtag - 1][1]; - pbc_shift_tmp[nedges][2] = x[j][2] - pos[jtag - 1][2]; + if (Rij < cutoff_square) { + // if given j is not inside cutoff + if (nedges >= nedges_upper_bound) { + error->all(FLERR, "nedges exceeded nedges_upper_bound"); + } + edge_idx_src[nedges] = i_graph_idx; + edge_idx_dst[nedges] = j_graph_idx; + edge_vec[nedges][0] = delij[0]; + edge_vec[nedges][1] = delij[1]; + edge_vec[nedges][2] = delij[2]; nedges++; } } // j loop end } // i loop end + // convert data to Tensor + auto inp_node_type = torch::from_blob(node_type.data(), nlocal, INTEGER_TYPE); + auto inp_num_atoms = torch::from_blob(num_atoms, {1}, INTEGER_TYPE); + auto edge_idx_src_tensor = torch::from_blob(edge_idx_src, {nedges}, INTEGER_TYPE); auto edge_idx_dst_tensor = @@ -204,66 +177,102 @@ void PairE3GNN::compute(int eflag, int vflag) { auto inp_edge_index = torch::stack({edge_idx_src_tensor, edge_idx_dst_tensor}); - // r' = r + {shift_tensor(integer vector of len 3)} @ cell_tensor - // shift_tensor = (cell_tensor)^-1^T @ (r' - r) - torch::Tensor cell_inv_tensor = - inp_cell.inverse().transpose(0, 1).unsqueeze(0).to(device); - torch::Tensor pbc_shift_tmp_tensor = - torch::from_blob(pbc_shift_tmp, {nedges, 3}, FLOAT_TYPE) - .view({nedges, 3, 1}) - .to(device); - torch::Tensor inp_cell_shift = - torch::bmm(cell_inv_tensor.expand({nedges, 3, 3}), pbc_shift_tmp_tensor) - .view({nedges, 3}); - - inp_pos.set_requires_grad(true); - - c10::Dict input_dict; + auto inp_edge_vec = torch::from_blob(edge_vec, {nedges, 3}, FLOAT_TYPE); + if (print_info) { + std::cout << " Nlocal: " << nlocal << std::endl; + std::cout << " Nedges: " << nedges << "\n" << std::endl; + } + + auto edge_vec_device = inp_edge_vec.to(device); + edge_vec_device.set_requires_grad(true); + + torch::Dict input_dict; input_dict.insert("x", inp_node_type.to(device)); - input_dict.insert("pos", inp_pos.to(device)); input_dict.insert("edge_index", inp_edge_index.to(device)); + input_dict.insert("edge_vec", edge_vec_device); input_dict.insert("num_atoms", inp_num_atoms.to(device)); - input_dict.insert("cell_lattice_vectors", inp_cell.to(device)); - input_dict.insert("cell_volume", inp_cell_volume.to(device)); - input_dict.insert("pbc_shift", inp_cell_shift); + input_dict.insert("nlocal", inp_num_atoms.to(torch::kCPU)); std::vector input(1, input_dict); auto output = model.forward(input).toGenericDict(); - torch::Tensor total_energy_tensor = - output.at("inferred_total_energy").toTensor().cpu(); - torch::Tensor force_tensor = output.at("inferred_force").toTensor().cpu(); + torch::Tensor energy_tensor = + output.at("inferred_total_energy").toTensor().squeeze(); + + // dE_dr + auto grads = torch::autograd::grad({energy_tensor}, {edge_vec_device}); + torch::Tensor dE_dr = grads[0].to(torch::kCPU); + + eng_vdwl += energy_tensor.detach().to(torch::kCPU).item(); + torch::Tensor force_tensor = torch::zeros({nlocal, 3}); + + auto _edge_idx_src_tensor = + edge_idx_src_tensor.repeat_interleave(3).view({nedges, 3}); + auto _edge_idx_dst_tensor = + edge_idx_dst_tensor.repeat_interleave(3).view({nedges, 3}); + + force_tensor.scatter_reduce_(0, _edge_idx_src_tensor, dE_dr, "sum"); + force_tensor.scatter_reduce_(0, _edge_idx_dst_tensor, torch::neg(dE_dr), + "sum"); + auto forces = force_tensor.accessor(); - eng_vdwl += total_energy_tensor.item(); - for (int itag = 0; itag < nlocal; itag++) { - int i = tag2i[itag]; - f[i][0] += forces[itag][0]; - f[i][1] += forces[itag][1]; - f[i][2] += forces[itag][2]; + for (int graph_idx = 0; graph_idx < nlocal; graph_idx++) { + int i = graph_index_to_i[graph_idx]; + f[i][0] += forces[graph_idx][0]; + f[i][1] += forces[graph_idx][1]; + f[i][2] += forces[graph_idx][2]; } + // Virial stress from edge contributions if (vflag) { - // more accurately, it is virial part of stress - torch::Tensor stress_tensor = output.at("inferred_stress").toTensor().cpu(); - auto virial_stress_tensor = stress_tensor * inp_cell_volume; - // xy yz zx order in vasp (voigt is xx yy zz yz xz xy) + auto diag = inp_edge_vec * dE_dr; + auto s12 = inp_edge_vec.select(1, 0) * dE_dr.select(1, 1); + auto s23 = inp_edge_vec.select(1, 1) * dE_dr.select(1, 2); + auto s31 = inp_edge_vec.select(1, 2) * dE_dr.select(1, 0); + std::vector voigt_list = { + diag, s12.unsqueeze(-1), s23.unsqueeze(-1), s31.unsqueeze(-1)}; + auto voigt = torch::cat(voigt_list, 1); + + torch::Tensor per_atom_stress_tensor = torch::zeros({nlocal, 6}); + auto _edge_idx_dst6_tensor = + edge_idx_dst_tensor.repeat_interleave(6).view({nedges, 6}); + per_atom_stress_tensor.scatter_reduce_(0, _edge_idx_dst6_tensor, voigt, + "sum"); + + auto virial_stress_tensor = + torch::neg(torch::sum(per_atom_stress_tensor, 0)); auto virial_stress = virial_stress_tensor.accessor(); + virial[0] += virial_stress[0]; virial[1] += virial_stress[1]; virial[2] += virial_stress[2]; virial[3] += virial_stress[3]; virial[4] += virial_stress[5]; virial[5] += virial_stress[4]; + + if (vflag_atom) { + auto per_atom_stress = per_atom_stress_tensor.accessor(); + + for (int gi = 0; gi < nlocal; gi++) { + const int i = graph_index_to_i[gi]; + vatom[i][0] += -per_atom_stress[gi][0]; + vatom[i][1] += -per_atom_stress[gi][1]; + vatom[i][2] += -per_atom_stress[gi][2]; + vatom[i][3] += -per_atom_stress[gi][3]; + vatom[i][4] += -per_atom_stress[gi][5]; + vatom[i][5] += -per_atom_stress[gi][4]; + } + } } if (eflag_atom) { torch::Tensor atomic_energy_tensor = - output.at("atomic_energy").toTensor().cpu().view({nlocal}); + output.at("atomic_energy").toTensor().to(torch::kCPU).view({nlocal}); auto atomic_energy = atomic_energy_tensor.accessor(); - for (int itag = 0; itag < nlocal; itag++) { - int i = tag2i[itag]; - eatom[i] += atomic_energy[itag]; + for (int gi = 0; gi < nlocal; gi++) { + const int i = graph_index_to_i[gi]; + eatom[i] += atomic_energy[gi]; } } diff --git a/sevenn/scripts/deploy.py b/sevenn/scripts/deploy.py index c81a9ad1..aa7d62c8 100644 --- a/sevenn/scripts/deploy.py +++ b/sevenn/scripts/deploy.py @@ -20,11 +20,7 @@ def deploy( use_flash: bool = False, use_oeq: bool = False, ) -> None: - from sevenn.nn.edge_embedding import EdgePreprocess - from sevenn.nn.force_output import ForceStressOutput - cp = load_checkpoint(checkpoint) - model, config = ( cp.build_model( enable_cueq=False, @@ -35,11 +31,8 @@ def deploy( cp.config, ) - model.prepand_module('edge_preprocess', EdgePreprocess(True)) - grad_module = ForceStressOutput() - model.replace_module('force_output', grad_module) - new_grad_key = grad_module.get_grad_key() - model.key_grad = new_grad_key + if 'force_output' in model._modules: + model.delete_module_by_key('force_output') if hasattr(model, 'eval_type_map'): setattr(model, 'eval_type_map', False) diff --git a/sevenn/scripts/inference.py b/sevenn/scripts/inference.py index c3e83428..70a8401e 100644 --- a/sevenn/scripts/inference.py +++ b/sevenn/scripts/inference.py @@ -18,6 +18,7 @@ def write_inference_csv(output_list, out): output = output.fit_dimension() output[KEY.STRESS] = output[KEY.STRESS] * 1602.1766208 output[KEY.PRED_STRESS] = output[KEY.PRED_STRESS] * 1602.1766208 + # atomic virial: keep model units (energy-like) output_list[i] = output.to_numpy_dict() per_graph_keys = [ @@ -35,6 +36,7 @@ def write_inference_csv(output_list, out): KEY.POS, KEY.FORCE, KEY.PRED_FORCE, + KEY.PRED_ATOMIC_VIRIAL, ] def unfold_dct_val(dct, keys, suffix_list=None): @@ -53,15 +55,36 @@ def unfold_dct_val(dct, keys, suffix_list=None): return res def per_atom_dct_list(dct, keys): - sfx_list = ['x', 'y', 'z'] res = [] - natoms = dct[KEY.NUM_ATOMS] - extracted = {k: dct[k] for k in keys} + natoms = int(dct[KEY.NUM_ATOMS]) for i in range(natoms): - raw = {} - raw.update({k: v[i] for k, v in extracted.items()}) - per_atom_dct = unfold_dct_val(raw, keys, suffix_list=sfx_list) - res.append(per_atom_dct) + entry = {} + for k in keys: + if k not in dct: + continue + v = dct[k] + if isinstance(v, np.ndarray): + if v.ndim == 0: + entry[k] = v.item() + elif v.ndim == 1: + entry[k] = v[i] + elif v.ndim == 2: + d = v.shape[1] + if k in (KEY.POS, KEY.FORCE, KEY.PRED_FORCE): + sfx = ['x', 'y', 'z'] + elif k == KEY.PRED_ATOMIC_VIRIAL: + sfx = ['xx', 'yy', 'zz', 'xy', 'yz', 'zx'] + else: + sfx = [str(j) for j in range(d)] + for j in range(d): + entry[f'{k}_{sfx[j]}'] = v[i, j] + else: + flat = v[i].ravel() + for j, val in enumerate(flat): + entry[f'{k}_{j}'] = val + else: + entry[k] = v + res.append(entry) return res try: diff --git a/sevenn/util.py b/sevenn/util.py index d7ce7888..20cf0464 100644 --- a/sevenn/util.py +++ b/sevenn/util.py @@ -43,6 +43,12 @@ def to_atom_graph_list(atom_graph_batch) -> List[_const.AtomGraphDataType]: if is_stress: inferred_stress_list = torch.unbind(atom_graph_batch[KEY.PRED_STRESS]) + inferred_virial_list = None + if KEY.PRED_ATOMIC_VIRIAL in atom_graph_batch: + inferred_virial_list = torch.split( + atom_graph_batch[KEY.PRED_ATOMIC_VIRIAL], indices + ) + for i, data in enumerate(data_list): data[KEY.ATOMIC_ENERGY] = atomic_energy_list[i] data[KEY.PRED_TOTAL_ENERGY] = inferred_total_energy_list[i] @@ -50,6 +56,8 @@ def to_atom_graph_list(atom_graph_batch) -> List[_const.AtomGraphDataType]: # To fit with KEY.STRESS (ref) format if is_stress and inferred_stress_list is not None: data[KEY.PRED_STRESS] = torch.unsqueeze(inferred_stress_list[i], 0) + if inferred_virial_list is not None: + data[KEY.PRED_ATOMIC_VIRIAL] = inferred_virial_list[i] return data_list diff --git a/tests/lammps_tests/scripts/stress_skel.lmp b/tests/lammps_tests/scripts/stress_skel.lmp new file mode 100644 index 00000000..10b69e8c --- /dev/null +++ b/tests/lammps_tests/scripts/stress_skel.lmp @@ -0,0 +1,21 @@ + units metal + boundary __BOUNDARY__ + read_data __LMP_STCT__ + + mass * 1.0 # do not matter since we don't run MD + + pair_style __PAIR_STYLE__ + pair_coeff * * __POTENTIALS__ __ELEMENT__ + + timestep 0.002 + + compute pa all pe/atom + compute astress all stress/atom NULL pair + + thermo 1 + fix 1 all nve + thermo_style custom step tpcpu pe ke vol pxx pyy pzz pxy pxz pyz press temp + dump mydump all custom 1 __FORCE_DUMP_PATH__ id type element c_pa c_astress[1] c_astress[2] c_astress[3] c_astress[4] c_astress[5] c_astress[6] x y z fx fy fz + dump_modify mydump sort id element __ELEMENT__ + + run 0 diff --git a/tests/lammps_tests/test_lammps.py b/tests/lammps_tests/test_lammps.py index 8847feea..ec213084 100644 --- a/tests/lammps_tests/test_lammps.py +++ b/tests/lammps_tests/test_lammps.py @@ -28,6 +28,9 @@ lmp_script_path = str( (pathlib.Path(__file__).parent / 'scripts' / 'skel.lmp').resolve() ) +lmp_stress_script_path = str( + (pathlib.Path(__file__).parent / 'scripts' / 'stress_skel.lmp').resolve() +) data_root = (pathlib.Path(__file__).parent.parent / 'data').resolve() cp_0_path = str(data_root / 'checkpoints' / 'cp_0.pth') # knows Hf, O @@ -106,6 +109,21 @@ def ref_modal_calculator(): return SevenNetCalculator(cp_mf_path, modal='PBE') +@pytest.fixture(scope='module') +def ref_stress_calculator(): + return SevenNetCalculator(cp_0_path, atomic_virial=True) + + +@pytest.fixture(scope='module') +def ref_7net0_stress_calculator(): + return SevenNetCalculator(cp_7net0_path, atomic_virial=True) + + +@pytest.fixture(scope='module') +def ref_modal_stress_calculator(): + return SevenNetCalculator(cp_mf_path, modal='PBE', atomic_virial=True) + + def get_model_config(): config = { 'cutoff': cutoff, @@ -176,7 +194,7 @@ def get_system(system_name, **kwargs): raise ValueError() -def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6): +def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6, check_atomic_stress=False): def acl(a, b, rtol=rtol, atol=atol): return np.allclose(a, b, rtol=rtol, atol=atol) @@ -190,6 +208,11 @@ def acl(a, b, rtol=rtol, atol=atol): rtol * 10, atol * 10, ) + if check_atomic_stress: + ref_atomic_virial = np.asarray(atoms1.calc.results['stresses']) + lmp_atomic_stress = np.asarray(atoms2.calc.results['atomic_stress']) + lmp_atomic_virial = -lmp_atomic_stress[:, [0, 1, 2, 3, 5, 4]] + assert acl(ref_atomic_virial, lmp_atomic_virial, rtol * 10, atol * 10) # assert acl(atoms1.get_potential_energies(), atoms2.get_potential_energies()) @@ -219,7 +242,7 @@ def _lammps_results_to_atoms(lammps_log, force_dump): 'lmp_dump': force_dump, } # atomic energy read - latoms.calc.results['energies'] = latoms.arrays['c_pa'][:, 0] + latoms.calc.results['energies'] = np.ravel(latoms.arrays['c_pa']) stress = np.array( [ [lmp_log['Pxx'], lmp_log['Pxy'], lmp_log['Pxz']], @@ -230,10 +253,41 @@ def _lammps_results_to_atoms(lammps_log, force_dump): stress = -1 * stress / 1602.1766208 / 1000 # convert bars to eV/A^3 latoms.calc.results['stress'] = stress + if 'c_astress[1]' in latoms.arrays: + atomic_stress = np.column_stack( + [ + np.asarray(latoms.arrays['c_astress[1]']), + np.asarray(latoms.arrays['c_astress[2]']), + np.asarray(latoms.arrays['c_astress[3]']), + np.asarray(latoms.arrays['c_astress[4]']), + np.asarray(latoms.arrays['c_astress[5]']), + np.asarray(latoms.arrays['c_astress[6]']), + ] + ) + latoms.calc.results['atomic_stress'] = atomic_stress / 1602.1766208 / 1000 + return latoms -def _run_lammps(atoms, pair_style, potential, wd, command, test_name): +def _run_lammps(atoms, pair_style, potential, wd, command, test_name, script_path): + def _rotate_stress(atomic_stress, rot_mat): + out = np.empty_like(atomic_stress) + for i, s in enumerate(atomic_stress): + sigma = np.array([ + [s[0], s[3], s[4]], + [s[3], s[1], s[5]], + [s[4], s[5], s[2]] + ]) + sigma = rot_mat @ sigma @ rot_mat.T + out[i] = [ + sigma[0, 0], + sigma[1, 1], + sigma[2, 2], + sigma[0, 1], + sigma[0, 2], + sigma[1, 2] + ] + return out wd = wd.resolve() pbc = atoms.get_pbc() pbc_str = ' '.join(['p' if x else 'f' for x in pbc]) @@ -248,7 +302,7 @@ def _run_lammps(atoms, pair_style, potential, wd, command, test_name): lmp_stct, atoms, prismobj=prism, specorder=chem ) - with open(lmp_script_path, 'r') as f: + with open(script_path, 'r') as f: cont = f.read() lammps_log = str(wd / 'log.lammps') @@ -276,6 +330,10 @@ def _run_lammps(atoms, pair_style, potential, wd, command, test_name): rot_mat = prism.rot_mat results = copy.deepcopy(lmp_atoms.calc.results) + + # SinglePointCalculator does not know atomic_stress + at_stress = results.pop('atomic_stress', None) + r_force = np.dot(results['forces'], rot_mat.T) results['forces'] = r_force if 'stress' in results: @@ -287,19 +345,33 @@ def _run_lammps(atoms, pair_style, potential, wd, command, test_name): lmp_atoms.set_cell(r_cell, scale_atoms=True) lmp_atoms = SinglePointCalculator(lmp_atoms, **results).get_atoms() + if at_stress is not None: + lmp_atoms.calc.results['atomic_stress'] = _rotate_stress(at_stress, rot_mat) + return lmp_atoms def serial_lammps_run(atoms, potential, wd, test_name, lammps_cmd): command = lammps_cmd - return _run_lammps(atoms, 'e3gnn', potential, wd, command, test_name) + return _run_lammps( + atoms, 'e3gnn', potential, wd, command, test_name, lmp_script_path + ) def parallel_lammps_run( atoms, potential, wd, test_name, ncores, lammps_cmd, mpirun_cmd ): command = f'{mpirun_cmd} -np {ncores} {lammps_cmd}' - return _run_lammps(atoms, 'e3gnn/parallel', potential, wd, command, test_name) + return _run_lammps( + atoms, 'e3gnn/parallel', potential, wd, command, test_name, lmp_script_path + ) + + +def serial_stress_lammps_run(atoms, potential, wd, test_name, lammps_cmd): + command = lammps_cmd + return _run_lammps( + atoms, 'e3gnn', potential, wd, command, test_name, lmp_stress_script_path + ) def subprocess_routine(cmd, name): @@ -370,6 +442,95 @@ def test_serial_flash( assert_atoms(atoms, atoms_lammps, atol=1e-5) +@pytest.mark.parametrize( + 'system', + ['bulk', 'surface'], +) +def test_serial_stress( + system, serial_potential_path, ref_stress_calculator, lammps_cmd, tmp_path +): + atoms = get_system(system) + atoms_lammps = serial_stress_lammps_run( + atoms=atoms, + potential=serial_potential_path, + wd=tmp_path, + test_name='serial stress lmp test', + lammps_cmd=lammps_cmd, + ) + atoms.calc = ref_stress_calculator + assert_atoms(atoms, atoms_lammps, atol=1e-5, check_atomic_stress=True) + + +@pytest.mark.skipif(not is_oeq_available(), reason='oeq not available') +@pytest.mark.parametrize( + 'system', + ['bulk', 'surface'], +) +def test_serial_stress_oeq( + system, + serial_potential_path_oeq, + ref_7net0_stress_calculator, + lammps_cmd, + tmp_path +): + atoms = get_system(system) + atoms_lammps = serial_stress_lammps_run( + atoms=atoms, + potential=serial_potential_path_oeq, + wd=tmp_path, + test_name='serial oeq stress lmp test', + lammps_cmd=lammps_cmd, + ) + atoms.calc = ref_7net0_stress_calculator + assert_atoms(atoms, atoms_lammps, atol=1e-5, check_atomic_stress=True) + + +@pytest.mark.skipif(not is_flash_available(), reason='flash not available') +@pytest.mark.parametrize( + 'system', + ['bulk', 'surface'], +) +def test_serial_stress_flash( + system, + serial_potential_path_flash, + ref_7net0_stress_calculator, + lammps_cmd, + tmp_path +): + atoms = get_system(system) + atoms_lammps = serial_stress_lammps_run( + atoms=atoms, + potential=serial_potential_path_flash, + wd=tmp_path, + test_name='serial flash stress lmp test', + lammps_cmd=lammps_cmd, + ) + atoms.calc = ref_7net0_stress_calculator + assert_atoms(atoms, atoms_lammps, atol=1e-5, check_atomic_stress=True) + + +@pytest.mark.parametrize( + 'system', + ['bulk', 'surface'], +) +def test_modal_serial_stress( + system, + serial_modal_potential_path, + ref_modal_stress_calculator, + lammps_cmd, tmp_path +): + atoms = get_system(system) + atoms_lammps = serial_stress_lammps_run( + atoms=atoms, + potential=serial_modal_potential_path, + wd=tmp_path, + test_name='modal serial lmp test', + lammps_cmd=lammps_cmd, + ) + atoms.calc = ref_modal_stress_calculator + assert_atoms(atoms, atoms_lammps) + + @pytest.mark.parametrize( 'system,ncores', [ diff --git a/tests/unit_tests/test_calculator.py b/tests/unit_tests/test_calculator.py index 9b19308d..6b11b717 100644 --- a/tests/unit_tests/test_calculator.py +++ b/tests/unit_tests/test_calculator.py @@ -8,7 +8,6 @@ from sevenn.calculator import D3Calculator, SevenNetCalculator from sevenn.nn.cue_helper import is_cue_available from sevenn.nn.flash_helper import is_flash_available -from sevenn.scripts.deploy import deploy from sevenn.util import model_from_checkpoint, pretrained_name_to_path @@ -105,26 +104,6 @@ def test_sevennet_0_cal_mol(atoms_mol, sevennet_0_cal): assert np.allclose(atoms_mol.get_potential_energies(), atoms2_ref['energies']) -def test_sevennet_0_cal_deployed_consistency(tmp_path, atoms_pbc): - atoms_pbc.rattle(stdev=0.01, seed=42) - fname = str(tmp_path / '7net_0.pt') - deploy(pretrained_name_to_path('7net-0_11July2024'), fname) - - calc_script = SevenNetCalculator(fname, file_type='torchscript') - calc_cp = SevenNetCalculator(pretrained_name_to_path('7net-0_11July2024')) - - atoms_pbc.calc = calc_cp - atoms_pbc.get_potential_energy() - res_cp = copy.copy(atoms_pbc.calc.results) - - atoms_pbc.calc = calc_script - atoms_pbc.get_potential_energy() - res_script = copy.copy(atoms_pbc.calc.results) - - for k in res_cp: - assert np.allclose(res_cp[k], res_script[k], rtol=1e-4, atol=1e-4) - - def test_sevennet_0_cal_as_instance_consistency(atoms_pbc): atoms_pbc.rattle(stdev=0.01, seed=42) model, _ = model_from_checkpoint(pretrained_name_to_path('7net-0_11July2024'))