diff --git a/sevenn/_const.py b/sevenn/_const.py index 6dc45589..9c4a3c26 100644 --- a/sevenn/_const.py +++ b/sevenn/_const.py @@ -18,13 +18,21 @@ IMPLEMENTED_SHIFT = ['per_atom_energy_mean', 'elemwise_reference_energies'] IMPLEMENTED_SCALE = ['force_rms', 'per_atom_energy_std', 'elemwise_force_rms'] -SUPPORTING_METRICS = ['RMSE', 'ComponentRMSE', 'MAE', 'Loss'] +SUPPORTING_METRICS = [ + 'RMSE', + 'ComponentRMSE', + 'MAE', + 'Loss', + 'DiagRMSE', + 'OffDiagRMSE' +] SUPPORTING_ERROR_TYPES = [ 'TotalEnergy', 'Energy', 'Force', 'Stress', 'Stress_GPa', + 'BornEffectiveCharges', 'TotalLoss', ] @@ -256,8 +264,10 @@ def data_defaults(config): KEY.OPTIM_PARAM: {}, KEY.SCHEDULER: 'exponentiallr', KEY.SCHEDULER_PARAM: {}, + KEY.ENERGY_WEIGHT: 1.0, KEY.FORCE_WEIGHT: 0.1, KEY.STRESS_WEIGHT: 1e-6, # SIMPLE-NN default + KEY.BEC_WEIGHT: 1.0, KEY.PER_EPOCH: 5, # KEY.USE_TESTSET: False, KEY.CONTINUE: { @@ -272,6 +282,7 @@ def data_defaults(config): KEY.CSV_LOG: 'log.csv', KEY.NUM_WORKERS: 0, KEY.IS_TRAIN_STRESS: True, + KEY.IS_TRAIN_BEC: False, KEY.TRAIN_SHUFFLE: True, KEY.ERROR_RECORD: [ ['Energy', 'RMSE'], @@ -288,8 +299,10 @@ def data_defaults(config): TRAINING_CONFIG_CONDITION = { KEY.RANDOM_SEED: int, KEY.EPOCH: int, + KEY.ENERGY_WEIGHT: float, KEY.FORCE_WEIGHT: float, KEY.STRESS_WEIGHT: float, + KEY.BEC_WEIGHT: float, KEY.USE_TESTSET: None, # Not used KEY.NUM_WORKERS: int, KEY.PER_EPOCH: int, @@ -303,6 +316,7 @@ def data_defaults(config): }, KEY.DEFAULT_MODAL: str, KEY.IS_TRAIN_STRESS: bool, + KEY.IS_TRAIN_BEC: bool, KEY.TRAIN_SHUFFLE: bool, KEY.ERROR_RECORD: error_record_condition, KEY.BEST_METRIC: str, @@ -313,9 +327,37 @@ def data_defaults(config): def train_defaults(config): - defaults = DEFAULT_TRAINING_CONFIG + defaults = DEFAULT_TRAINING_CONFIG.copy() if KEY.IS_TRAIN_STRESS not in config: config[KEY.IS_TRAIN_STRESS] = defaults[KEY.IS_TRAIN_STRESS] if not config[KEY.IS_TRAIN_STRESS]: defaults.pop(KEY.STRESS_WEIGHT, None) + + if KEY.IS_TRAIN_BEC not in config: + config[KEY.IS_TRAIN_BEC] = defaults[KEY.IS_TRAIN_BEC] + + # Automatically add BEC metrics if enabled and default err record + if config[KEY.IS_TRAIN_BEC]: + # If the user didn't explicitly provide an ERROR_RECORD, or if they provided + # the default one, we append the BEC Diag/OffDiag metrics automatically + current_err = config.get(KEY.ERROR_RECORD, defaults[KEY.ERROR_RECORD]) + if type(current_err) is list: + new_err = [list(e) for e in current_err] + if not any(e[0] == 'BornEffectiveCharges' for e in new_err): + # Insert before TotalLoss + total_loss_idx = len(new_err) + for i, e in enumerate(new_err): + if e[0] == 'TotalLoss': + total_loss_idx = i + break + new_err.insert( + total_loss_idx, ['BornEffectiveCharges', 'DiagRMSE'] + ) + new_err.insert( + total_loss_idx + 1, ['BornEffectiveCharges', 'OffDiagRMSE'] + ) + config[KEY.ERROR_RECORD] = new_err + else: + defaults.pop(KEY.BEC_WEIGHT, None) + return defaults diff --git a/sevenn/_keys.py b/sevenn/_keys.py index 0c9af7b7..83587c8c 100644 --- a/sevenn/_keys.py +++ b/sevenn/_keys.py @@ -49,6 +49,7 @@ ENERGY: Final[str] = 'total_energy' # (1) FORCE: Final[str] = 'force_of_atoms' # (N, 3) STRESS: Final[str] = 'stress' # (6) +BORN_EFFECTIVE_CHARGES: Final[str] = 'born_effective_charges' # (N, 3, 3) # This is for training, per atom scale. SCALED_ENERGY: Final[str] = 'scaled_total_energy' @@ -67,6 +68,8 @@ PRED_STRESS: Final[str] = 'inferred_stress' SCALED_STRESS: Final[str] = 'scaled_stress' +PRED_BORN_EFFECTIVE_CHARGES: Final[str] = 'inferred_born_effective_charges' + # very general data property for AtomGraphData NUM_ATOMS: Final[str] = 'num_atoms' # int NUM_GHOSTS: Final[str] = 'num_ghosts' @@ -116,14 +119,17 @@ OPTIM_PARAM = 'optim_param' SCHEDULER = 'scheduler' SCHEDULER_PARAM = 'scheduler_param' +ENERGY_WEIGHT = 'energy_loss_weight' FORCE_WEIGHT = 'force_loss_weight' STRESS_WEIGHT = 'stress_loss_weight' +BEC_WEIGHT = 'bec_loss_weight' DEVICE = 'device' DTYPE = 'dtype' TRAIN_SHUFFLE = 'train_shuffle' IS_TRAIN_STRESS = 'is_train_stress' +IS_TRAIN_BEC = 'is_train_bec' CONTINUE = 'continue' CHECKPOINT = 'checkpoint' diff --git a/sevenn/calculator.py b/sevenn/calculator.py index 2f4a3d59..535b679f 100644 --- a/sevenn/calculator.py +++ b/sevenn/calculator.py @@ -183,6 +183,7 @@ def __init__( 'forces', 'stress', 'energies', + 'born_effective_charges', ] def set_atoms(self, atoms: Atoms) -> None: @@ -207,7 +208,7 @@ def output_to_results(self, output: Dict[str, torch.Tensor]) -> Dict[str, Any]: .numpy()[[0, 1, 2, 4, 5, 3]] # as voigt notation ) # Store results - return { + res = { 'free_energy': energy, 'energy': energy, 'energies': atomic_energies, @@ -216,6 +217,25 @@ def output_to_results(self, output: Dict[str, torch.Tensor]) -> Dict[str, Any]: 'num_edges': output[KEY.EDGE_IDX].shape[1], } + if KEY.PRED_BORN_EFFECTIVE_CHARGES in output: + if getattr(self, '_ct', None) is None: + from e3nn.io import CartesianTensor + self._ct = CartesianTensor('ij') + self._rtp = self._ct.reduced_tensor_products() + + ct = self._ct + rtp = self._rtp + pred_bec_irreps = output[KEY.PRED_BORN_EFFECTIVE_CHARGES].detach().cpu() + + # Convert 9-component irreps (1x0e+1x1e+1x2e) to 3x3 Cartesian tensors + pred_bec_cartesian = ct.to_cartesian( + pred_bec_irreps, rtp.to(pred_bec_irreps.device) + ) + + res['born_effective_charges'] = pred_bec_cartesian.numpy()[:num_atoms] + + return res + def calculate(self, atoms=None, properties=None, system_changes=all_changes): is_ts_type = isinstance(self.model, torch_script_type) diff --git a/sevenn/error_recorder.py b/sevenn/error_recorder.py index 262ea06f..e150c035 100644 --- a/sevenn/error_recorder.py +++ b/sevenn/error_recorder.py @@ -60,6 +60,13 @@ 'coeff': 160.21766208, 'vdim': 6, }, + 'BornEffectiveCharges': { + 'name': 'BornEffectiveCharges', + 'ref_key': KEY.BORN_EFFECTIVE_CHARGES, + 'pred_key': KEY.PRED_BORN_EFFECTIVE_CHARGES, + 'unit': 'e', + 'vdim': 9, + }, 'TotalLoss': { 'name': 'TotalLoss', 'unit': None, @@ -127,6 +134,19 @@ def __init__( self.ignore_unlabeled = ignore_unlabeled self.value = AverageNumber() + self.is_bec = ( + self.ref_key == KEY.BORN_EFFECTIVE_CHARGES + and self.pred_key == KEY.PRED_BORN_EFFECTIVE_CHARGES + ) + + def _get_cartesian_tensor(self) -> Any: + if getattr(self, '_ct', None) is None: + import e3nn.io + from e3nn.io import CartesianTensor + self._ct = CartesianTensor('ij') + self._rtp = self._ct.reduced_tensor_products() + return self._ct, self._rtp + def update(self, output: 'AtomGraphData') -> None: raise NotImplementedError @@ -135,13 +155,28 @@ def _retrieve( ) -> Tuple[torch.Tensor, torch.Tensor]: y_ref = output[self.ref_key] * self.coeff y_pred = output[self.pred_key] * self.coeff + + # If BornEffectiveCharges, convert irreps (pred) to cartesian + if self.is_bec: + ct, rtp = self._get_cartesian_tensor() + if y_pred.shape[-1] == 9: + y_pred = ct.to_cartesian(y_pred, rtp.to(y_pred.device)) + y_pred = y_pred.view(-1, 9) + if y_ref.shape[-1] == 3 and y_ref.dim() == 3: + y_ref = y_ref.view(-1, 9) + if self.per_atom: assert y_ref.dim() == 1 and y_pred.dim() == 1 natoms = output[KEY.NUM_ATOMS] y_ref = y_ref / natoms y_pred = y_pred / natoms if self.ignore_unlabeled: - unlabelled_idx = torch.isnan(y_ref) + if y_ref.dim() > 1: + unlabelled_idx = ( + torch.isnan(y_ref).view(y_ref.shape[0], -1).any(dim=1) + ) + else: + unlabelled_idx = torch.isnan(y_ref) y_ref = y_ref[~unlabelled_idx] y_pred = y_pred[~unlabelled_idx] return y_ref, y_pred @@ -165,6 +200,63 @@ def __str__(self): return f'{self.key_str()}: {self.value.get():.6f}' +class BECDiagRMSError(ErrorMetric): + """ + Computes RMSE strictly on the diagonal elements of a + 3x3 Born Effective Charge tensor. + """ + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self._se = torch.nn.MSELoss(reduction='none') + + def update(self, output: 'AtomGraphData') -> None: + y_ref, y_pred = self._retrieve(output) + if len(y_ref) == 0: + return + # Assumes y_ref and y_pred are flattened N*9 arrays, reshape to N, 3, 3 + y_ref = y_ref.view(-1, 3, 3) + y_pred = y_pred.view(-1, 3, 3) + + diag_idx = torch.arange(3) + y_ref_diag = y_ref[:, diag_idx, diag_idx].reshape(-1) + y_pred_diag = y_pred[:, diag_idx, diag_idx].reshape(-1) + + se = self._se(y_ref_diag, y_pred_diag) + self.value.update(se) + + def get(self) -> float: + return self.value.get() ** 0.5 + + +class BECOffDiagRMSError(ErrorMetric): + """ + Computes RMSE strictly on the off-diagonal elements of a + 3x3 Born Effective Charge tensor. + """ + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self._se = torch.nn.MSELoss(reduction='none') + + def update(self, output: 'AtomGraphData') -> None: + y_ref, y_pred = self._retrieve(output) + if len(y_ref) == 0: + return + # Assumes y_ref and y_pred are flattened N*9 arrays, reshape to N, 3, 3 + y_ref = y_ref.view(-1, 3, 3) + y_pred = y_pred.view(-1, 3, 3) + + # Create mask for off-diagonal elements + mask = ~torch.eye(3, dtype=torch.bool, device=y_ref.device) + y_ref_off = y_ref[:, mask].reshape(-1) + y_pred_off = y_pred[:, mask].reshape(-1) + + se = self._se(y_ref_off, y_pred_off) + self.value.update(se) + + def get(self) -> float: + return self.value.get() ** 0.5 + + class RMSError(ErrorMetric): """ Vector squared error @@ -317,6 +409,8 @@ class ErrorRecorder: 'ComponentRMSE': ComponentRMSError, 'MAE': MAError, 'Loss': LossError, + 'DiagRMSE': BECDiagRMSError, + 'OffDiagRMSE': BECOffDiagRMSError, } def __init__(self, metrics: List[ErrorMetric]) -> None: diff --git a/sevenn/model_build.py b/sevenn/model_build.py index c548c34e..f07f6986 100644 --- a/sevenn/model_build.py +++ b/sevenn/model_build.py @@ -4,6 +4,7 @@ from collections import OrderedDict from typing import Any, Dict, List, Literal, Tuple, Type, Union, overload +from e3nn.io import CartesianTensor from e3nn.o3 import Irreps import sevenn._const as _const @@ -567,8 +568,14 @@ def build_E3_equivariant_model( parity_mode = 'full' fix_multiplicity = False if t == num_convolution_layer - 1: - lmax_node = 0 - parity_mode = 'even' + # If training BEC, we need vectors/tensors to survive the last layer + if config.get(KEY.IS_TRAIN_BEC, False): + # We need at least L=1 and L=2 for vectors and tensors. + lmax_node = max(lmax_node, 2) + parity_mode = 'full' + else: + lmax_node = 0 + parity_mode = 'even' # TODO: irreps_manual is applicable to both irreps_out_tp and irreps_out irreps_out = ( util.infer_irreps_out( @@ -599,6 +606,20 @@ def build_E3_equivariant_model( layers.update(interaction_builder(**param_interaction_block)) irreps_x = irreps_out + if config.get(KEY.IS_TRAIN_BEC, False): + irreps_in_bec = irreps_x + layers.update( + { + 'predict_bec': IrrepsLinear( + irreps_in_bec, + Irreps('1x0e+1x1e+1x2e'), + data_key_in=KEY.NODE_FEATURE, + data_key_out=KEY.PRED_BORN_EFFECTIVE_CHARGES, + biases=config[KEY.USE_BIAS_IN_LINEAR], + ) + } + ) + layers.update(init_feature_reduce(config, irreps_x)) # type: ignore layers.update( diff --git a/sevenn/train/dataload.py b/sevenn/train/dataload.py index 545131ee..1dfc2958 100644 --- a/sevenn/train/dataload.py +++ b/sevenn/train/dataload.py @@ -160,6 +160,7 @@ def atoms_to_graph( y_energy = atoms.info['y_energy'] y_force = atoms.arrays['y_force'] y_stress = atoms.info.get('y_stress', np.full((6,), np.nan)) + y_bec = atoms.arrays.get('y_bec', np.full((len(atoms), 3, 3), np.nan)) if y_stress.shape == (3, 3): y_stress = np.array( [ @@ -178,11 +179,16 @@ def atoms_to_graph( y_energy = from_calc['energy'] y_force = from_calc['force'] y_stress = from_calc['stress'] + y_bec = from_calc['born_effective_charges'] assert y_stress.shape == (6,), 'If you see this, please raise a issue' if not allow_unlabeled and (np.isnan(y_energy) or np.isnan(y_force).any()): raise ValueError('Unlabeled E or F found, set allow_unlabeled True') + if y_bec.shape == (len(atoms), 9): + y_bec = y_bec.reshape((len(atoms), 3, 3)) + assert y_bec.shape == (len(atoms), 3, 3), 'If you see this, please raise a issue' + pos = atoms.get_positions() cell = np.array(atoms.get_cell()) pbc = atoms.get_pbc() @@ -204,6 +210,7 @@ def atoms_to_graph( KEY.CELL_VOLUME: _correct_scalar(atoms.cell.volume), KEY.NUM_ATOMS: _correct_scalar(len(atomic_numbers)), KEY.PER_ATOM_ENERGY: _correct_scalar(y_energy / len(pos)), + KEY.BORN_EFFECTIVE_CHARGES: y_bec, } if with_shift: @@ -274,6 +281,7 @@ def _y_from_calc(atoms: ase.Atoms): 'energy': np.nan, 'force': np.full((len(atoms), 3), np.nan), 'stress': np.full((6,), np.nan), + 'born_effective_charges': np.full((len(atoms), 3, 3), np.nan), } if atoms.calc is None: @@ -294,6 +302,14 @@ def _y_from_calc(atoms: ase.Atoms): ret['stress'] = np.array(y_stress[[0, 1, 2, 5, 3, 4]]) except RuntimeError: pass + + try: + ret['born_effective_charges'] = atoms.calc.results.get( + 'born_effective_charges', np.full((len(atoms), 3, 3), np.nan) + ) + except AttributeError: + pass + return ret @@ -302,11 +318,12 @@ def _set_atoms_y( energy_key: Optional[str] = None, force_key: Optional[str] = None, stress_key: Optional[str] = None, + bec_key: Optional[str] = None, ) -> List[ase.Atoms]: """ Define how SevenNet reads ASE.atoms object for its y label - If energy_key, force_key, or stress_key is given, the corresponding - label is obtained from .info dict of Atoms object. These values should + If energy_key, force_key, stress_key, or bec_key is given, the corresponding + label is obtained from .info or .arrays dict of Atoms object. These values should have eV, eV/Angstrom, and eV/Angstrom^3 for energy, force, and stress, respectively. (stress in Voigt notation) @@ -315,6 +332,7 @@ def _set_atoms_y( energy_key (str, optional): key to get energy. Defaults to None. force_key (str, optional): key to get force. Defaults to None. stress_key (str, optional): key to get stress. Defaults to None. + bec_key (str, optional): key to get born effective charges. Defaults to None. Returns: list[ase.Atoms]: list of ase.Atoms @@ -345,6 +363,13 @@ def _set_atoms_y( else: atoms.info['y_stress'] = from_calc['stress'] + if bec_key is not None: + atoms.arrays['y_bec'] = atoms.arrays.pop(bec_key) + elif 'born_effective_charges' in atoms.arrays: + atoms.arrays['y_bec'] = atoms.arrays.pop('born_effective_charges') + else: + atoms.arrays['y_bec'] = from_calc['born_effective_charges'] + return atoms_list @@ -353,6 +378,7 @@ def ase_reader( energy_key: Optional[str] = None, force_key: Optional[str] = None, stress_key: Optional[str] = None, + bec_key: Optional[str] = None, index: str = ':', **kwargs, ) -> List[ase.Atoms]: @@ -363,7 +389,7 @@ def ase_reader( if not isinstance(atoms_list, list): atoms_list = [atoms_list] - return _set_atoms_y(atoms_list, energy_key, force_key, stress_key) + return _set_atoms_y(atoms_list, energy_key, force_key, stress_key, bec_key) # Reader diff --git a/sevenn/train/graph_dataset.py b/sevenn/train/graph_dataset.py index 224e6e22..52552849 100644 --- a/sevenn/train/graph_dataset.py +++ b/sevenn/train/graph_dataset.py @@ -65,7 +65,13 @@ def _run_stat( """ Loop over dataset and init any statistics might need """ - y_keys = y_keys or [KEY.ENERGY, KEY.PER_ATOM_ENERGY, KEY.FORCE, KEY.STRESS] + y_keys = y_keys or [ + KEY.ENERGY, + KEY.PER_ATOM_ENERGY, + KEY.FORCE, + KEY.STRESS, + KEY.BORN_EFFECTIVE_CHARGES, + ] n_neigh = [] natoms_counter = Counter() composition = torch.zeros((len(graph_list), NUM_UNIV_ELEMENT)) @@ -79,14 +85,17 @@ def _run_stat( composition[i] = torch.bincount(z_tensor, minlength=NUM_UNIV_ELEMENT) n_neigh.append(torch.unique(graph[KEY.EDGE_IDX][0], return_counts=True)[1]) for y, dct in stats.items(): - dct['_array'].append( - graph[y].reshape( - -1, + if y in graph and graph[y] is not None: + dct['_array'].append( + graph[y].reshape( + -1, + ) ) - ) stats.update({'num_neighbor': {'_array': n_neigh}}) for y, dct in stats.items(): + if len(dct['_array']) == 0: + continue array = torch.cat(dct['_array']) if array.dtype == torch.int64: # because of n_neigh array = array.to(torch.float) diff --git a/sevenn/train/loss.py b/sevenn/train/loss.py index a6f8a769..c50597ff 100644 --- a/sevenn/train/loss.py +++ b/sevenn/train/loss.py @@ -201,6 +201,77 @@ def _preprocess( return pred, ref, w_tensor +class BECLoss(LossDefinition): + """ + Loss for Born Effective Charges + """ + + def __init__( + self, + name: str = 'BornEffectiveCharges', + unit: str = 'e', + criterion: Optional[Callable] = None, + ref_key: str = KEY.BORN_EFFECTIVE_CHARGES, + pred_key: str = KEY.PRED_BORN_EFFECTIVE_CHARGES, + **kwargs, + ) -> None: + super().__init__( + name=name, + unit=unit, + criterion=criterion, + ref_key=ref_key, + pred_key=pred_key, + **kwargs, + ) + + def _get_cartesian_tensor(self) -> Any: + if getattr(self, '_ct', None) is None: + import e3nn.io + from e3nn.io import CartesianTensor + self._ct = CartesianTensor('ij') + self._rtp = self._ct.reduced_tensor_products() + return self._ct, self._rtp + + def _preprocess( + self, batch_data: Dict[str, Any], model: Optional[Callable] = None + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + assert isinstance(self.pred_key, str) and isinstance(self.ref_key, str) + + # pred is 9 components (1x0e+1x1e+1x2e irreps format) + pred = batch_data[self.pred_key] + + # ref is Cartesian tensor 3x3 format (or 9 flat cartesian) + ref_cartesian = batch_data[self.ref_key] + if ref_cartesian.shape[-1] == 9 and ref_cartesian.dim() == 2: + ref_cartesian = ref_cartesian.reshape(-1, 3, 3) + + # Convert true cartesian to irreps format (N, 9) + ct, rtp = self._get_cartesian_tensor() + ref_irreps = ct.from_cartesian(ref_cartesian, rtp.to(ref_cartesian.device)) + + pred = torch.reshape(pred, (-1,)) + ref = torch.reshape(ref_irreps, (-1,)) + w_tensor = None + + if self.use_weight: + loss_type = self.name.lower() + weight = batch_data[KEY.DATA_WEIGHT][loss_type] + w_tensor = weight[batch_data[KEY.BATCH]] + w_tensor = torch.repeat_interleave(w_tensor, 9) + + return pred, ref, w_tensor + + def get_loss(self, batch_data: Dict[str, Any], model: Optional[Callable] = None): + """ + Function that return scalar. + Overridden for BECLoss to compensate for 9-component flattening. + Flattening divides the mean loss by N*9 instead of N. We multiply by 9 + to restore per-atom loss scaling, ensuring consistent gradient magnitudes. + """ + loss = super().get_loss(batch_data, model) + return loss * 9.0 + + def get_loss_functions_from_config( config: Dict[str, Any], ) -> List[Tuple[LossDefinition, float]]: @@ -218,10 +289,14 @@ def get_loss_functions_from_config( commons = {'use_weight': use_weight} - loss_functions.append((PerAtomEnergyLoss(**commons), 1.0)) + loss_functions.append( + (PerAtomEnergyLoss(**commons), config.get(KEY.ENERGY_WEIGHT, 1.0)) + ) loss_functions.append((ForceLoss(**commons), config[KEY.FORCE_WEIGHT])) if config[KEY.IS_TRAIN_STRESS]: loss_functions.append((StressLoss(**commons), config[KEY.STRESS_WEIGHT])) + if config.get(KEY.IS_TRAIN_BEC, False): + loss_functions.append((BECLoss(**commons), config[KEY.BEC_WEIGHT])) for loss_function, _ in loss_functions: # why do these? if loss_function.criterion is None: diff --git a/sevenn/train/optim.py b/sevenn/train/optim.py index 10e75790..013d03c2 100644 --- a/sevenn/train/optim.py +++ b/sevenn/train/optim.py @@ -20,4 +20,4 @@ 'linearlr': scheduler.LinearLR, } -loss_dict = {'mse': nn.MSELoss, 'huber': nn.HuberLoss} +loss_dict = {'mse': nn.MSELoss, 'huber': nn.HuberLoss, 'mae': nn.L1Loss} diff --git a/tests/unit_tests/test_data.py b/tests/unit_tests/test_data.py index 0b3e7b6e..e2ec2df5 100644 --- a/tests/unit_tests/test_data.py +++ b/tests/unit_tests/test_data.py @@ -208,12 +208,12 @@ def test_graph_build(): for k in g1.keys(): if not isinstance(g1[k], torch.Tensor): continue - if k == 'stress': # TODO: robust way to test it - assert torch.allclose(g1[k], g2[k]) or ( + if k in ['stress', 'born_effective_charges']: # TODO: robust test + assert torch.allclose(g1[k], g2[k], equal_nan=True) or ( torch.isnan(g1[k]).all() == torch.isnan(g2[k]).all() ) else: - assert torch.allclose(g1[k], g2[k]) + assert torch.allclose(g1[k], g2[k], equal_nan=True) @pytest.fixture(scope='module')