Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 44 additions & 2 deletions sevenn/_const.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,21 @@
IMPLEMENTED_SHIFT = ['per_atom_energy_mean', 'elemwise_reference_energies']
IMPLEMENTED_SCALE = ['force_rms', 'per_atom_energy_std', 'elemwise_force_rms']

SUPPORTING_METRICS = ['RMSE', 'ComponentRMSE', 'MAE', 'Loss']
SUPPORTING_METRICS = [
'RMSE',
'ComponentRMSE',
'MAE',
'Loss',
'DiagRMSE',
'OffDiagRMSE'
]
SUPPORTING_ERROR_TYPES = [
'TotalEnergy',
'Energy',
'Force',
'Stress',
'Stress_GPa',
'BornEffectiveCharges',
'TotalLoss',
]

Expand Down Expand Up @@ -256,8 +264,10 @@ def data_defaults(config):
KEY.OPTIM_PARAM: {},
KEY.SCHEDULER: 'exponentiallr',
KEY.SCHEDULER_PARAM: {},
KEY.ENERGY_WEIGHT: 1.0,
KEY.FORCE_WEIGHT: 0.1,
KEY.STRESS_WEIGHT: 1e-6, # SIMPLE-NN default
KEY.BEC_WEIGHT: 1.0,
KEY.PER_EPOCH: 5,
# KEY.USE_TESTSET: False,
KEY.CONTINUE: {
Expand All @@ -272,6 +282,7 @@ def data_defaults(config):
KEY.CSV_LOG: 'log.csv',
KEY.NUM_WORKERS: 0,
KEY.IS_TRAIN_STRESS: True,
KEY.IS_TRAIN_BEC: False,
KEY.TRAIN_SHUFFLE: True,
KEY.ERROR_RECORD: [
['Energy', 'RMSE'],
Expand All @@ -288,8 +299,10 @@ def data_defaults(config):
TRAINING_CONFIG_CONDITION = {
KEY.RANDOM_SEED: int,
KEY.EPOCH: int,
KEY.ENERGY_WEIGHT: float,
KEY.FORCE_WEIGHT: float,
KEY.STRESS_WEIGHT: float,
KEY.BEC_WEIGHT: float,
KEY.USE_TESTSET: None, # Not used
KEY.NUM_WORKERS: int,
KEY.PER_EPOCH: int,
Expand All @@ -303,6 +316,7 @@ def data_defaults(config):
},
KEY.DEFAULT_MODAL: str,
KEY.IS_TRAIN_STRESS: bool,
KEY.IS_TRAIN_BEC: bool,
KEY.TRAIN_SHUFFLE: bool,
KEY.ERROR_RECORD: error_record_condition,
KEY.BEST_METRIC: str,
Expand All @@ -313,9 +327,37 @@ def data_defaults(config):


def train_defaults(config):
defaults = DEFAULT_TRAINING_CONFIG
defaults = DEFAULT_TRAINING_CONFIG.copy()
if KEY.IS_TRAIN_STRESS not in config:
config[KEY.IS_TRAIN_STRESS] = defaults[KEY.IS_TRAIN_STRESS]
if not config[KEY.IS_TRAIN_STRESS]:
defaults.pop(KEY.STRESS_WEIGHT, None)

if KEY.IS_TRAIN_BEC not in config:
config[KEY.IS_TRAIN_BEC] = defaults[KEY.IS_TRAIN_BEC]

# Automatically add BEC metrics if enabled and default err record
if config[KEY.IS_TRAIN_BEC]:
# If the user didn't explicitly provide an ERROR_RECORD, or if they provided
# the default one, we append the BEC Diag/OffDiag metrics automatically
current_err = config.get(KEY.ERROR_RECORD, defaults[KEY.ERROR_RECORD])
if type(current_err) is list:
new_err = [list(e) for e in current_err]
if not any(e[0] == 'BornEffectiveCharges' for e in new_err):
# Insert before TotalLoss
total_loss_idx = len(new_err)
for i, e in enumerate(new_err):
if e[0] == 'TotalLoss':
total_loss_idx = i
break
new_err.insert(
total_loss_idx, ['BornEffectiveCharges', 'DiagRMSE']
)
new_err.insert(
total_loss_idx + 1, ['BornEffectiveCharges', 'OffDiagRMSE']
)
config[KEY.ERROR_RECORD] = new_err
else:
defaults.pop(KEY.BEC_WEIGHT, None)

return defaults
6 changes: 6 additions & 0 deletions sevenn/_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
ENERGY: Final[str] = 'total_energy' # (1)
FORCE: Final[str] = 'force_of_atoms' # (N, 3)
STRESS: Final[str] = 'stress' # (6)
BORN_EFFECTIVE_CHARGES: Final[str] = 'born_effective_charges' # (N, 3, 3)

# This is for training, per atom scale.
SCALED_ENERGY: Final[str] = 'scaled_total_energy'
Expand All @@ -67,6 +68,8 @@
PRED_STRESS: Final[str] = 'inferred_stress'
SCALED_STRESS: Final[str] = 'scaled_stress'

PRED_BORN_EFFECTIVE_CHARGES: Final[str] = 'inferred_born_effective_charges'

# very general data property for AtomGraphData
NUM_ATOMS: Final[str] = 'num_atoms' # int
NUM_GHOSTS: Final[str] = 'num_ghosts'
Expand Down Expand Up @@ -116,14 +119,17 @@
OPTIM_PARAM = 'optim_param'
SCHEDULER = 'scheduler'
SCHEDULER_PARAM = 'scheduler_param'
ENERGY_WEIGHT = 'energy_loss_weight'
FORCE_WEIGHT = 'force_loss_weight'
STRESS_WEIGHT = 'stress_loss_weight'
BEC_WEIGHT = 'bec_loss_weight'
DEVICE = 'device'
DTYPE = 'dtype'

TRAIN_SHUFFLE = 'train_shuffle'

IS_TRAIN_STRESS = 'is_train_stress'
IS_TRAIN_BEC = 'is_train_bec'

CONTINUE = 'continue'
CHECKPOINT = 'checkpoint'
Expand Down
22 changes: 21 additions & 1 deletion sevenn/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def __init__(
'forces',
'stress',
'energies',
'born_effective_charges',
]

def set_atoms(self, atoms: Atoms) -> None:
Expand All @@ -207,7 +208,7 @@ def output_to_results(self, output: Dict[str, torch.Tensor]) -> Dict[str, Any]:
.numpy()[[0, 1, 2, 4, 5, 3]] # as voigt notation
)
# Store results
return {
res = {
'free_energy': energy,
'energy': energy,
'energies': atomic_energies,
Expand All @@ -216,6 +217,25 @@ def output_to_results(self, output: Dict[str, torch.Tensor]) -> Dict[str, Any]:
'num_edges': output[KEY.EDGE_IDX].shape[1],
}

if KEY.PRED_BORN_EFFECTIVE_CHARGES in output:
if getattr(self, '_ct', None) is None:
from e3nn.io import CartesianTensor
self._ct = CartesianTensor('ij')
self._rtp = self._ct.reduced_tensor_products()

ct = self._ct
rtp = self._rtp
pred_bec_irreps = output[KEY.PRED_BORN_EFFECTIVE_CHARGES].detach().cpu()

# Convert 9-component irreps (1x0e+1x1e+1x2e) to 3x3 Cartesian tensors
pred_bec_cartesian = ct.to_cartesian(
pred_bec_irreps, rtp.to(pred_bec_irreps.device)
)

res['born_effective_charges'] = pred_bec_cartesian.numpy()[:num_atoms]

return res

def calculate(self, atoms=None, properties=None, system_changes=all_changes):
is_ts_type = isinstance(self.model, torch_script_type)

Expand Down
96 changes: 95 additions & 1 deletion sevenn/error_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@
'coeff': 160.21766208,
'vdim': 6,
},
'BornEffectiveCharges': {
'name': 'BornEffectiveCharges',
'ref_key': KEY.BORN_EFFECTIVE_CHARGES,
'pred_key': KEY.PRED_BORN_EFFECTIVE_CHARGES,
'unit': 'e',
'vdim': 9,
},
'TotalLoss': {
'name': 'TotalLoss',
'unit': None,
Expand Down Expand Up @@ -127,6 +134,19 @@ def __init__(
self.ignore_unlabeled = ignore_unlabeled
self.value = AverageNumber()

self.is_bec = (
self.ref_key == KEY.BORN_EFFECTIVE_CHARGES
and self.pred_key == KEY.PRED_BORN_EFFECTIVE_CHARGES
)

def _get_cartesian_tensor(self) -> Any:
if getattr(self, '_ct', None) is None:
import e3nn.io
from e3nn.io import CartesianTensor
self._ct = CartesianTensor('ij')
self._rtp = self._ct.reduced_tensor_products()
return self._ct, self._rtp

def update(self, output: 'AtomGraphData') -> None:
raise NotImplementedError

Expand All @@ -135,13 +155,28 @@ def _retrieve(
) -> Tuple[torch.Tensor, torch.Tensor]:
y_ref = output[self.ref_key] * self.coeff
y_pred = output[self.pred_key] * self.coeff

# If BornEffectiveCharges, convert irreps (pred) to cartesian
if self.is_bec:
ct, rtp = self._get_cartesian_tensor()
if y_pred.shape[-1] == 9:
y_pred = ct.to_cartesian(y_pred, rtp.to(y_pred.device))
y_pred = y_pred.view(-1, 9)
if y_ref.shape[-1] == 3 and y_ref.dim() == 3:
y_ref = y_ref.view(-1, 9)

if self.per_atom:
assert y_ref.dim() == 1 and y_pred.dim() == 1
natoms = output[KEY.NUM_ATOMS]
y_ref = y_ref / natoms
y_pred = y_pred / natoms
if self.ignore_unlabeled:
unlabelled_idx = torch.isnan(y_ref)
if y_ref.dim() > 1:
unlabelled_idx = (
torch.isnan(y_ref).view(y_ref.shape[0], -1).any(dim=1)
)
else:
unlabelled_idx = torch.isnan(y_ref)
y_ref = y_ref[~unlabelled_idx]
y_pred = y_pred[~unlabelled_idx]
return y_ref, y_pred
Expand All @@ -165,6 +200,63 @@ def __str__(self):
return f'{self.key_str()}: {self.value.get():.6f}'


class BECDiagRMSError(ErrorMetric):
"""
Computes RMSE strictly on the diagonal elements of a
3x3 Born Effective Charge tensor.
"""
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self._se = torch.nn.MSELoss(reduction='none')

def update(self, output: 'AtomGraphData') -> None:
y_ref, y_pred = self._retrieve(output)
if len(y_ref) == 0:
return
# Assumes y_ref and y_pred are flattened N*9 arrays, reshape to N, 3, 3
y_ref = y_ref.view(-1, 3, 3)
y_pred = y_pred.view(-1, 3, 3)

diag_idx = torch.arange(3)
y_ref_diag = y_ref[:, diag_idx, diag_idx].reshape(-1)
y_pred_diag = y_pred[:, diag_idx, diag_idx].reshape(-1)

se = self._se(y_ref_diag, y_pred_diag)
self.value.update(se)

def get(self) -> float:
return self.value.get() ** 0.5


class BECOffDiagRMSError(ErrorMetric):
"""
Computes RMSE strictly on the off-diagonal elements of a
3x3 Born Effective Charge tensor.
"""
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self._se = torch.nn.MSELoss(reduction='none')

def update(self, output: 'AtomGraphData') -> None:
y_ref, y_pred = self._retrieve(output)
if len(y_ref) == 0:
return
# Assumes y_ref and y_pred are flattened N*9 arrays, reshape to N, 3, 3
y_ref = y_ref.view(-1, 3, 3)
y_pred = y_pred.view(-1, 3, 3)

# Create mask for off-diagonal elements
mask = ~torch.eye(3, dtype=torch.bool, device=y_ref.device)
y_ref_off = y_ref[:, mask].reshape(-1)
y_pred_off = y_pred[:, mask].reshape(-1)

se = self._se(y_ref_off, y_pred_off)
self.value.update(se)

def get(self) -> float:
return self.value.get() ** 0.5


class RMSError(ErrorMetric):
"""
Vector squared error
Expand Down Expand Up @@ -317,6 +409,8 @@ class ErrorRecorder:
'ComponentRMSE': ComponentRMSError,
'MAE': MAError,
'Loss': LossError,
'DiagRMSE': BECDiagRMSError,
'OffDiagRMSE': BECOffDiagRMSError,
}

def __init__(self, metrics: List[ErrorMetric]) -> None:
Expand Down
25 changes: 23 additions & 2 deletions sevenn/model_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections import OrderedDict
from typing import Any, Dict, List, Literal, Tuple, Type, Union, overload

from e3nn.io import CartesianTensor
from e3nn.o3 import Irreps

import sevenn._const as _const
Expand Down Expand Up @@ -567,8 +568,14 @@ def build_E3_equivariant_model(
parity_mode = 'full'
fix_multiplicity = False
if t == num_convolution_layer - 1:
lmax_node = 0
parity_mode = 'even'
# If training BEC, we need vectors/tensors to survive the last layer
if config.get(KEY.IS_TRAIN_BEC, False):
# We need at least L=1 and L=2 for vectors and tensors.
lmax_node = max(lmax_node, 2)
parity_mode = 'full'
else:
lmax_node = 0
parity_mode = 'even'
# TODO: irreps_manual is applicable to both irreps_out_tp and irreps_out
irreps_out = (
util.infer_irreps_out(
Expand Down Expand Up @@ -599,6 +606,20 @@ def build_E3_equivariant_model(
layers.update(interaction_builder(**param_interaction_block))
irreps_x = irreps_out

if config.get(KEY.IS_TRAIN_BEC, False):
irreps_in_bec = irreps_x
layers.update(
{
'predict_bec': IrrepsLinear(
irreps_in_bec,
Irreps('1x0e+1x1e+1x2e'),
data_key_in=KEY.NODE_FEATURE,
data_key_out=KEY.PRED_BORN_EFFECTIVE_CHARGES,
biases=config[KEY.USE_BIAS_IN_LINEAR],
)
}
)

layers.update(init_feature_reduce(config, irreps_x)) # type: ignore

layers.update(
Expand Down
Loading
Loading