diff --git a/.github/workflows/architecture-tests.yml b/.github/workflows/architecture-tests.yml
index b9fcc68686..6a602fa980 100644
--- a/.github/workflows/architecture-tests.yml
+++ b/.github/workflows/architecture-tests.yml
@@ -20,7 +20,8 @@ jobs:
           - mace
           - nanopet
           - pet
-          - soap-bpnn
+          - phace
+          - soap-bpnn
 
     runs-on: ubuntu-22.04
diff --git a/CODEOWNERS b/CODEOWNERS
index 239b7ce203..d61794cb09 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -6,6 +6,7 @@
 **/pet @abmazitov
 **/gap @DavideTisi
 **/nanopet @frostedoyster
+**/phace @frostedoyster
 **/llpr @frostedoyster @SanggyuChong
 **/flashmd @johannes-spies @frostedoyster
 **/classifier @frostedoyster
diff --git a/examples/1-advanced/03-fitting-generic-targets.py b/examples/1-advanced/03-fitting-generic-targets.py
index ddda39e41d..10691d78cd 100644
--- a/examples/1-advanced/03-fitting-generic-targets.py
+++ b/examples/1-advanced/03-fitting-generic-targets.py
@@ -25,7 +25,7 @@
      - Energy, forces, stress/virial
      - Yes
      - Yes
-     - No
+     - Only with ``rank=1`` (vectors)
    * - GAP
      - Energy, forces
      - No
@@ -46,6 +46,11 @@
      - Yes
      - Yes
      - Only with ``rank=1`` (vectors)
+   * - PhACE
+     - Energy, forces, stress/virial
+     - Yes
+     - Yes
+     - Only with ``rank=1`` (vectors)
 
 Preparing generic targets for reading by metatrain
 --------------------------------------------------
diff --git a/pyproject.toml b/pyproject.toml
index c36d5a3b1e..c5e53c4c83 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -75,6 +75,11 @@ gap = [
     "skmatter",
     "scipy",
 ]
+phace = [
+    "physical_basis",
+    "wigners",
+    "opt-einsum",
+]
 llpr = []
 classifier = []
 mace = [
@@ -171,8 +176,10 @@ filterwarnings = [
     "ignore:No libgomp shared library found in 'sphericart_torch.libs'.",
     # Multi-threaded tests clash with multi-process data-loading
     "ignore:This process \\(pid=\\d+\\) is multi-threaded, use of fork\\(\\) may lead to deadlocks in the child.:DeprecationWarning",
+    # Initialization of tensors of zero size
+    "ignore:Initializing zero-element tensors is a no-op:UserWarning",
     # MACE warning with newer versions of pytorch (because they use e3nn==0.4.4)
-    "ignore:Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected, since the`weights_only` argument was not explicitly passed:UserWarning"
+    "ignore:Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected, since the`weights_only` argument was not explicitly passed:UserWarning",
 ]
 addopts = ["-p", "mtt_plugin"]
 pythonpath = "src/metatrain/utils/testing"
diff --git a/src/metatrain/deprecated/nanopet/trainer.py b/src/metatrain/deprecated/nanopet/trainer.py
index 579d3cecf3..7327643f54 100644
--- a/src/metatrain/deprecated/nanopet/trainer.py
+++ b/src/metatrain/deprecated/nanopet/trainer.py
@@ -94,7 +94,7 @@ def train(
         # Move the model to the device and dtype:
         model.to(device=device, dtype=dtype)
-        # The additive models of the SOAP-BPNN are always in float64 (to avoid
+        # The additive models of NanoPET are always in float64 (to avoid
         # numerical errors in the composition weights, which can be very large).
         for additive_model in model.additive_models:
             additive_model.to(dtype=torch.float64)
diff --git a/src/metatrain/experimental/phace/__init__.py b/src/metatrain/experimental/phace/__init__.py
new file mode 100644
index 0000000000..ab340614e6
--- /dev/null
+++ b/src/metatrain/experimental/phace/__init__.py
@@ -0,0 +1,14 @@
+from .model import PhACE
+from .trainer import Trainer
+
+
+__model__ = PhACE
+__trainer__ = Trainer
+
+__authors__ = [
+    ("Filippo Bigi ", "@frostedoyster"),
+]
+
+__maintainers__ = [
+    ("Filippo Bigi ", "@frostedoyster"),
+]
diff --git a/src/metatrain/experimental/phace/checkpoints.py b/src/metatrain/experimental/phace/checkpoints.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/metatrain/experimental/phace/documentation.py b/src/metatrain/experimental/phace/documentation.py
new file mode 100644
index 0000000000..8ab5b6917b
--- /dev/null
+++ b/src/metatrain/experimental/phace/documentation.py
@@ -0,0 +1,280 @@
+"""
+PhACE
+=====
+
+PhACE is a physics-inspired equivariant neural network architecture. Compared to, for
+example, MACE and GRACE, it uses a geometrically motivated basis and a fast and
+elegant tensor product implementation. The tensor product used in PhACE leverages an
+equivariant representation that differs from the typical spherical one. You can read
+more about it here: https://pubs.acs.org/doi/10.1021/acs.jpclett.4c02376.
+
+{{SECTION_INSTALLATION}}
+
+{{SECTION_DEFAULT_HYPERS}}
+
+Tuning hyperparameters
+----------------------
+
+The default hyperparameters above will work well in most cases, but they
+may not be optimal for your specific use case. There is a good number of
+parameters to tune, both for the
+:ref:`model ` and the
+:ref:`trainer `. Here, we provide a
+**list of the parameters that are in general the most important** (in decreasing order
+of importance) for the PhACE architecture:
+
+.. container:: mtt-hypers-remove-classname
+
+    .. autoattribute:: {{model_hypers_path}}.radial_basis
+        :no-index:
+
+    .. autoattribute:: {{model_hypers_path}}.num_element_channels
+        :no-index:
+
+    .. autoattribute:: {{trainer_hypers_path}}.num_epochs
+        :no-index:
+
+    .. autoattribute:: {{trainer_hypers_path}}.batch_size
+        :no-index:
+
+    .. autoattribute:: {{model_hypers_path}}.num_message_passing_layers
+        :no-index:
+
+    .. autoattribute:: {{trainer_hypers_path}}.learning_rate
+        :no-index:
+
+    .. autoattribute:: {{model_hypers_path}}.cutoff
+        :no-index:
+
+    .. autoattribute:: {{model_hypers_path}}.force_rectangular
+        :no-index:
+
+    .. autoattribute:: {{model_hypers_path}}.spherical_linear_layers
+        :no-index:
+"""
+
+from typing import Literal, Optional
+
+from typing_extensions import TypedDict
+
+from metatrain.utils.additive import FixedCompositionWeights
+from metatrain.utils.hypers import init_with_defaults
+from metatrain.utils.loss import LossSpecification
+from metatrain.utils.scaler import FixedScalerWeights
+
+
+class RadialBasisHypers(TypedDict):
+    """Hyperparameters concerning the radial basis functions used in the model."""
+
+    max_eigenvalue: float = 25.0
+    """Maximum eigenvalue for the radial basis."""
+
+    scale: float = 0.7
+    """Scaling factor for the radial basis."""
+
+    optimizable_lengthscales: bool = False
+    """Whether the length scales in the radial basis are optimizable."""
+
+
+###########################
+#  MODEL HYPERPARAMETERS  #
+###########################
+
+
+class ModelHypers(TypedDict):
+    """Hyperparameters for the experimental.phace model."""
+
+    max_correlation_order_per_layer: int = 3
+    """Maximum correlation order per layer."""
+
+    num_message_passing_layers: int = 2
+    """Number of message passing layers.
+
+    Increasing this value might increase the accuracy of the model (especially on
+    larger datasets), at the expense of computational efficiency.
+    """
+
+    cutoff: float = 5.0
+    """Cutoff radius for neighbor search.
+
+    This should be set to a value after which most of the interactions
+    between atoms are expected to be negligible. A lower cutoff will lead
+    to faster models.
+    """
+
+    cutoff_width: float = 1.0
+    """Width of the cutoff smoothing function."""
+
+    num_element_channels: int = 128
+    """Number of channels per element.
+
+    This determines the size of the embedding used to encode the atomic species, and it
+    increases or decreases the size of the internal features used in the model.
+    """
+
+    force_rectangular: bool = False
+    """Makes the number of channels per irrep the same.
+
+    This might improve accuracy with a limited increase in computational cost.
+    """
+
+    spherical_linear_layers: bool = False
+    """Whether to perform linear layers in the spherical representation."""
+
+    radial_basis: RadialBasisHypers = init_with_defaults(RadialBasisHypers)
+    """Hyperparameters for the radial basis functions.
+
+    Raising ``max_eigenvalue`` from its default will increase the number of spherical
+    irreducible representations (irreps) used in the model, which can improve accuracy
+    at the cost of computational efficiency. Increasing this value will also increase
+    the number of radial basis functions (and therefore internal features) used for
+    each irrep.
+    """
+
+    nu_scaling: float = 0.1
+    """Scaling for the nu term."""
+
+    mp_scaling: float = 0.1
+    """Scaling for message passing."""
+
+    overall_scaling: float = 1.0
+    """Overall scaling factor."""
+
+    disable_nu_0: bool = True
+    """Whether to disable nu=0."""
+
+    use_sphericart: bool = False
+    """Whether to use the ``sphericart`` library to compute spherical harmonics."""
+
+    head_num_layers: int = 1
+    """Number of layers in the head."""
+
+    heads: dict[str, Literal["linear", "mlp"]] = {}
+    """Heads to use in the model, with options being "linear" or "mlp"."""
+
+    zbl: bool = False
+    """Whether to use the ZBL potential in the model."""
+
+
+#############################
+#  TRAINER HYPERPARAMETERS  #
+#############################
+
+
+class TrainerHypers(TypedDict):
+    """Hyperparameters for training the experimental.phace model."""
+
+    compile: bool = True
+    """Whether to use ``torch.compile`` during training.
+
+    This can lead to significant speedups, but it will cause a compilation step at the
+    beginning of training which might take up to 5-10 minutes, mainly depending on
+    ``max_eigenvalue``.
+    """
+
+    distributed: bool = False
+    """Whether to use distributed training."""
+
+    distributed_port: int = 39591
+    """Port for DDP communication."""
+
+    batch_size: int = 8
+    """Batch size for training.
+
+    Decrease this value if you run into out-of-memory errors during training. You can
+    try to increase it if your structures are very small (less than 20 atoms) and you
+    have a good GPU.
+    """
+
+    num_epochs: int = 1000
+    """Number of epochs to train the model.
+
+    A larger number of epochs might lead to better accuracy. In general, if you see
+    that the validation metrics are not much worse than the training ones at the end of
+    training, it might be a good idea to increase this value.
+    """
+
+    learning_rate: float = 0.01
+    """Learning rate for the optimizer.
+
+    You can try to increase this value (e.g., to 0.02 or 0.03) if training is very
+    slow or decrease it (e.g., to 0.005 or less) if you see that training explodes in
+    the first few epochs.
+    """
+
+    warmup_fraction: float = 0.01
+    """Fraction of training steps for learning rate warmup."""
+
+    gradient_clipping: Optional[float] = None
+    """Gradient clipping value. If None, no clipping is applied."""
+
+    log_interval: int = 1
+    """Interval to log metrics during training."""
+
+    checkpoint_interval: int = 25
+    """Interval to save model checkpoints."""
+
+    scale_targets: bool = True
+    """Whether to scale targets during training."""
+
+    atomic_baseline: FixedCompositionWeights = {}
+    """The baselines for each target.
+
+    By default, ``metatrain`` will fit a linear model (:class:`CompositionModel
+    `) to compute the least squares baseline for each atomic species for each
+    target.
+
+    However, this hyperparameter allows you to provide your own baselines.
+    The value of the hyperparameter should be a dictionary where the keys are the
+    target names, and the values are either (1) a single baseline to be used for
+    all atomic types, or (2) a dictionary mapping atomic types to their baselines.
+    For example:
+
+    - ``atomic_baseline: {"energy": {1: -0.5, 6: -10.0}}`` will fix the energy
+      baseline for hydrogen (Z=1) to -0.5 and for carbon (Z=6) to -10.0, while
+      fitting the baselines for the energy of all other atomic types, as well
+      as fitting the baselines for all other targets.
+    - ``atomic_baseline: {"energy": -5.0}`` will fix the energy baseline for
+      all atomic types to -5.0.
+    - ``atomic_baseline: {"mtt:dos": 0.0}`` sets the baseline for the "mtt:dos"
+      target to 0.0, effectively disabling the atomic baseline for that target.
+
+    This atomic baseline is subtracted from the targets during training, which
+    avoids the main model needing to learn atomic contributions, and likely makes
+    training easier. When the model is used in evaluation mode, the atomic baseline
+    is added on top of the model predictions automatically.
+
+    .. note::
+        This atomic baseline is a per-atom contribution. Therefore, if the property
+        you are predicting is a sum over all atoms (e.g., total energy), the
+        contribution of the atomic baseline to the total property will be the
+        atomic baseline multiplied by the number of atoms of that type in the
+        structure.
+
+    .. note::
+        If a MACE model is loaded through the ``mace_model`` hyperparameter, the
+        atomic baselines in the MACE model are used by default for the target
+        indicated in ``mace_head_target``.
+        If you want to override them, you need to explicitly set the baselines
+        for that target in this hyperparameter.
+    """
+
+    fixed_scaling_weights: FixedScalerWeights = {}
+    """Fixed scaling weights for the model."""
+
+    num_workers: Optional[int] = None
+    """Number of workers for data loading."""
+
+    per_structure_targets: list[str] = []
+    """List of targets to calculate per-structure losses."""
+
+    log_separate_blocks: bool = False
+    """Whether to log per-block error during training."""
+
+    log_mae: bool = False
+    """Whether to log MAE alongside RMSE during training."""
+
+    best_model_metric: Literal["rmse_prod", "mae_prod", "loss"] = "rmse_prod"
+    """Metric used to select the best model checkpoint."""
+
+    loss: str | dict[str, LossSpecification] = "mse"
+    """Loss function used for training."""
diff --git a/src/metatrain/experimental/phace/model.py b/src/metatrain/experimental/phace/model.py
new file mode 100644
index 0000000000..a2eaa86d88
--- /dev/null
+++ b/src/metatrain/experimental/phace/model.py
@@ -0,0 +1,600 @@
+import logging
+import warnings
+from typing import Any, Dict, List, Literal, Optional
+
+import metatensor.torch
+import torch
+from metatensor.torch import Labels, TensorBlock, TensorMap
+from metatensor.torch.operations._add import _add_block_block
+from metatomic.torch import (
+    AtomisticModel,
+    ModelCapabilities,
+    ModelMetadata,
+    ModelOutput,
+    NeighborListOptions,
+    System,
+)
+
+from metatrain.experimental.phace.documentation import ModelHypers
+from metatrain.experimental.phace.modules.base_model import (
+    BaseModel,
+    FakeGradientModel,
+    GradientModel,
+)
+from metatrain.experimental.phace.utils import systems_to_batch
+from metatrain.utils.abc import ModelInterface
+from metatrain.utils.additive import ZBL, CompositionModel
+from metatrain.utils.data.dataset import DatasetInfo, TargetInfo
+from metatrain.utils.dtype import dtype_to_str
+from metatrain.utils.metadata import merge_metadata
+from metatrain.utils.scaler import Scaler
+
+from . 
import checkpoints + + +warnings.filterwarnings( + "ignore", + category=UserWarning, + message=("The TorchScript type system doesn't support instance-level annotations"), +) +warnings.filterwarnings( + "ignore", + category=UserWarning, + message=("Initializing zero-element tensors is a no-op"), +) + + +class PhACE(ModelInterface[ModelHypers]): + __checkpoint_version__ = 1 + __supported_devices__ = ["cuda", "cpu"] + __supported_dtypes__ = [torch.float64, torch.float32] + __default_metadata__ = ModelMetadata(references={}) + + component_labels: Dict[str, List[List[Labels]]] + U_dict: Dict[int, torch.Tensor] + + def __init__(self, hypers: ModelHypers, dataset_info: DatasetInfo) -> None: + super().__init__(hypers, dataset_info, self.__default_metadata__) + + self.new_outputs = list(dataset_info.targets.keys()) + self.atomic_types = sorted(dataset_info.atomic_types) + + self.cutoff_radius = float(hypers["cutoff"]) + self.dataset_info = dataset_info + self.hypers = hypers + + # machinery to trick torchscript into liking our model + base_model = BaseModel(hypers, dataset_info) + self.fake_gradient_model = FakeGradientModel(base_model) + self.gradient_model = GradientModel(base_model) + self.module = self.fake_gradient_model + + self.k_max_l = self.module.module.k_max_l + self.l_max = len(self.k_max_l) - 1 + + self.overall_scaling = hypers["overall_scaling"] + + self.outputs = { + "features": ModelOutput(unit="", per_atom=True) + } # the model is always capable of outputting the internal features + for target_name in dataset_info.targets.keys(): + # the model can always output the last-layer features for the targets + ll_features_name = ( + f"mtt::aux::{target_name.replace('mtt::', '')}_last_layer_features" + ) + self.outputs[ll_features_name] = ModelOutput(per_atom=True) + + self.key_labels: Dict[str, Labels] = {} + self.component_labels: Dict[str, List[List[Labels]]] = {} + self.property_labels: Dict[str, List[Labels]] = {} + self.head_num_layers = self.hypers["head_num_layers"] + for target_name, target_info in dataset_info.targets.items(): + self._add_output(target_name, target_info) + + self.last_layer_feature_size = self.k_max_l[0] + + # additive models: these are handled by the trainer at training + # time, and they are added to the output at evaluation time + composition_model = CompositionModel( + hypers={}, + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if CompositionModel.is_valid_target(target_name, target_info) + }, + ), + ) + additive_models = [composition_model] + if self.hypers["zbl"]: + additive_models.append( + ZBL( + {}, + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if ZBL.is_valid_target(target_name, target_info) + }, + ), + ) + ) + self.additive_models = torch.nn.ModuleList(additive_models) + + # scaler: this is also handled by the trainer at training time + self.scaler = Scaler(hypers={}, dataset_info=dataset_info) + + self.single_label = Labels.single() + + @torch.jit.export + def supported_outputs(self) -> Dict[str, ModelOutput]: + return self.outputs + + def restart(self, dataset_info: DatasetInfo) -> "PhACE": + # merge old and new dataset info + merged_info = self.dataset_info.union(dataset_info) + new_atomic_types = [ + at for at in merged_info.atomic_types 
if at not in self.atomic_types + ] + new_targets = { + key: value + for key, value in merged_info.targets.items() + if key not in self.dataset_info.targets + } + self.has_new_targets = len(new_targets) > 0 + + if len(new_atomic_types) > 0: + raise ValueError( + f"New atomic types found in the dataset: {new_atomic_types}. " + "The PhACE model does not support adding new atomic types." + ) + + # register new outputs as new last layers + for target_name, target in new_targets.items(): + self._add_output(target_name, target) + + self.dataset_info = merged_info + + # restart the composition and scaler models + self.additive_models[0].restart( + dataset_info=DatasetInfo( + length_unit=dataset_info.length_unit, + atomic_types=self.atomic_types, + targets={ + target_name: target_info + for target_name, target_info in dataset_info.targets.items() + if CompositionModel.is_valid_target(target_name, target_info) + }, + ), + ) + self.scaler.restart(dataset_info) + + return self + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + # transfer labels, if needed + device = systems[0].device + if self.single_label.values.device != device: + self.single_label = self.single_label.to(device) + self.key_labels = { + output_name: label.to(device) + for output_name, label in self.key_labels.items() + } + self.component_labels = { + output_name: [ + [label.to(device) for label in component] + for component in components + ] + for output_name, components in self.component_labels.items() + } + self.property_labels = { + output_name: [label.to(device) for label in labels] + for output_name, labels in self.property_labels.items() + } + + # Convert systems to batch format + neighbor_list_options = self.requested_neighbor_lists()[0] # there is only one + batch = systems_to_batch(systems, neighbor_list_options) + + # compute sample labels from batch + samples_values = torch.stack( + [batch["structure_centers"], batch["centers"]], dim=1 + ) + samples = metatensor.torch.Labels( + names=["system", "atom"], + values=samples_values, + ) + + outputs_with_gradients: List[str] = [] + for output_name, output_info in outputs.items(): + if len(output_info.explicit_gradients) > 0: + outputs_with_gradients.append(output_name) + + predictions = self.module(batch, outputs_with_gradients) + + return_dict: Dict[str, TensorMap] = {} + + # output the features, if requested: + if "features" in outputs: + # only a single features block is supported by metatomic, we choose L=0 + features_tensor = predictions["features"][0].squeeze(1) + features = TensorMap( + keys=self.single_label, + blocks=[ + TensorBlock( + values=features_tensor, + samples=samples, + components=[], + properties=Labels( + names=["feature"], + values=torch.arange(features_tensor.shape[-1]).unsqueeze( + -1 + ), + ), + ) + ], + ) + if selected_atoms is not None: + features = metatensor.torch.slice( + features, axis="samples", selection=selected_atoms + ) + if outputs["features"].per_atom: + return_dict["features"] = features + else: + return_dict["features"] = metatensor.torch.sum_over_samples( + features, ["atom"] + ) + + # output the last-layer features for the outputs, if requested: + for output_name in outputs.keys(): + if not ( + output_name.startswith("mtt::aux::") + and output_name.endswith("_last_layer_features") + ): + continue + base_name = output_name.replace("mtt::aux::", "").replace( + "_last_layer_features", "" + ) + # the corresponding output could be 
base_name or mtt::base_name + if f"mtt::{base_name}" in self.outputs: + base_name = f"mtt::{base_name}" + + last_layer_features_as_dict_of_tensors = predictions[f"{base_name}__llf"] + return_dict[output_name] = TensorMap( + keys=Labels( + names=["o3_lambda"], + values=torch.arange(self.l_max + 1, device=device).unsqueeze(-1), + ), + blocks=[ + TensorBlock( + values=t, + samples=samples, + components=[ + Labels( + names=["o3_mu"], + values=torch.arange(-l, l + 1, device=device).unsqueeze( + -1 + ), + ) + ], + properties=Labels( + names=["feature"], + values=torch.arange(t.shape[-1], device=device).unsqueeze( + -1 + ), + ), + ) + for l, t in last_layer_features_as_dict_of_tensors.items() # noqa: E741 + ], + ) + if selected_atoms is not None: + return_dict[output_name] = metatensor.torch.slice( + return_dict[output_name], axis="samples", selection=selected_atoms + ) + if not outputs[output_name].per_atom: + return_dict[output_name] = metatensor.torch.sum_over_samples( + return_dict[output_name], ["atom"] + ) + + # remaining outputs (main outputs) + for output_name in outputs.keys(): + if output_name == "features" or output_name.startswith("mtt::aux::"): + continue + output_as_tensor_dict = predictions[output_name] + return_dict[output_name] = TensorMap( + keys=self.key_labels[output_name], + blocks=[ + TensorBlock( + values=( + output_as_tensor_dict[(len(c[0]) - 1) // 2] + if len(c) > 0 + else output_as_tensor_dict[0].squeeze(1) + ), + samples=samples, + components=c, + properties=p, + ) + for c, p in zip( + self.component_labels[output_name], + self.property_labels[output_name], + strict=True, + ) + ], + ) + # Handle Cartesian rank-1 outputs (e.g. direct forces) + if len(self.component_labels[output_name]) == 1: + if len(self.component_labels[output_name][0]) == 1: + if self.component_labels[output_name][0][0].names == ["xyz"]: + return_dict[output_name].block().values[:] = ( + return_dict[output_name].block().values[:, [2, 0, 1]] + ) + if selected_atoms is not None: + return_dict[output_name] = metatensor.torch.slice( + return_dict[output_name], axis="samples", selection=selected_atoms + ) + if not outputs[output_name].per_atom: + return_dict[output_name] = metatensor.torch.sum_over_samples( + return_dict[output_name], ["atom"] + ) + if len(outputs[output_name].explicit_gradients) == 0: + continue + original_block = return_dict[output_name].block() + block = TensorBlock( + values=original_block.values, + samples=original_block.samples, + components=original_block.components, + properties=original_block.properties, + ) + device = block.values.device + for gradient_name in outputs[output_name].explicit_gradients: + if gradient_name == "positions": + samples = Labels( + names=["sample", "atom"], + values=torch.stack( + [ + torch.concatenate( + [ + torch.tensor([i] * len(system), device=device) + for i, system in enumerate(systems) + ] + ), + torch.concatenate( + [ + torch.arange(len(system), device=device) + for system in systems + ] + ), + ], + dim=1, + ), + assume_unique=True, + ) + components = [ + Labels( + names=["xyz"], + values=torch.tensor([[0], [1], [2]], device=device), + ) + ] + gradient_tensor = predictions[f"{output_name}__pos"][-1] + elif gradient_name == "strain": + samples = Labels( + names=["sample"], + values=torch.arange(len(systems), device=device).unsqueeze(-1), + assume_unique=True, + ) + components = [ + Labels( + names=["xyz_1"], + values=torch.tensor([[0], [1], [2]], device=device), + ), + Labels( + names=["xyz_2"], + values=torch.tensor([[0], [1], [2]], 
device=device), + ), + ] + gradient_tensor = predictions[f"{output_name}__str"][-1] + else: + raise ValueError( + f"Unsupported explicit gradient request: {gradient_name}" + ) + block.add_gradient( + gradient_name, + TensorBlock( + values=gradient_tensor.unsqueeze(-1), + samples=samples.to(gradient_tensor.device), + components=components, + properties=Labels("energy", torch.tensor([[0]], device=device)), + ), + ) + return_dict[output_name] = TensorMap( + return_dict[output_name].keys, + [block], + ) + + if not self.training: + # at evaluation, we also introduce the scaler and additive contributions + return_dict = self.scaler(systems, return_dict) + for additive_model in self.additive_models: + outputs_for_additive_model: Dict[str, ModelOutput] = {} + for name, output in outputs.items(): + if name in additive_model.outputs: + outputs_for_additive_model[name] = output + additive_contributions = additive_model( + systems, + outputs_for_additive_model, + selected_atoms, + ) + for name in additive_contributions: + # TODO: uncomment this after metatensor.torch.add + # is updated to handle sparse sums + # return_dict[name] = metatensor.torch.add( + # return_dict[name], + # additive_contributions[name].to( + # device=return_dict[name].device, + # dtype=return_dict[name].dtype + # ), + # ) + # TODO: "manual" sparse sum: update to metatensor.torch.add + # after sparse sum is implemented in metatensor.operations + output_blocks: List[TensorBlock] = [] + for k, b in return_dict[name].items(): + if k in additive_contributions[name].keys: + output_blocks.append( + _add_block_block( + b, + additive_contributions[name] + .block(k) + .to(device=b.device, dtype=b.dtype), + ) + ) + else: + output_blocks.append(b) + return_dict[name] = TensorMap(return_dict[name].keys, output_blocks) + + return return_dict + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + context: Literal["restart", "finetune", "export"], + ) -> "PhACE": + if context == "restart": + logging.info(f"Using latest model from epoch {checkpoint['epoch']}") + model_state_dict = checkpoint["model_state_dict"] + elif context in {"finetune", "export"}: + logging.info(f"Using best model from epoch {checkpoint['best_epoch']}") + model_state_dict = checkpoint["best_model_state_dict"] + else: + raise ValueError("Unknown context tag for checkpoint loading!") + + # Create the model + model_data = checkpoint["model_data"] + model = cls( + hypers=model_data["model_hypers"], + dataset_info=model_data["dataset_info"], + ) + state_dict_iterator = iter(model_state_dict.values()) + next(state_dict_iterator) # skip an int tensor + next(state_dict_iterator) # skip another int tensor + dtype = next(state_dict_iterator).dtype + model.to(dtype).load_state_dict(model_state_dict) + model.additive_models[0].sync_tensor_maps() + model.scaler.sync_tensor_maps() + + # Loading the metadata from the checkpoint + model.metadata = merge_metadata(model.metadata, checkpoint.get("metadata")) + + return model + + def export(self, metadata: Optional[ModelMetadata] = None) -> AtomisticModel: + # Before exporting, we have to + # - set the module to the gradient-free one (torchscript doesn't like grad in + # the functional way they're used in the GradientModel) + # - delete the other models: even if the forward function doesn't use them, + # torchscript will try to compile them anyway + self.module = self.fake_gradient_model + del self.gradient_model + del self.fake_gradient_model + + dtype = next(self.parameters()).dtype + if dtype not in 
self.__supported_dtypes__:
+            raise ValueError(f"unsupported dtype {dtype} for PhACE")
+
+        # Make sure the model is all in the same dtype
+        # For example, after training, the additive models could still be in
+        # float64
+        self.to(dtype)
+
+        # Additionally, the composition model contains some `TensorMap`s that cannot
+        # be registered correctly with Pytorch. This function moves them:
+        self.additive_models[0].weights_to(torch.device("cpu"), torch.float64)
+
+        interaction_ranges = [
+            self.hypers["num_message_passing_layers"] * self.hypers["cutoff"]
+        ]
+        for additive_model in self.additive_models:
+            if hasattr(additive_model, "cutoff_radius"):
+                interaction_ranges.append(additive_model.cutoff_radius)
+        interaction_range = max(interaction_ranges)
+
+        capabilities = ModelCapabilities(
+            outputs=self.outputs,
+            atomic_types=self.atomic_types,
+            interaction_range=interaction_range,
+            length_unit=self.dataset_info.length_unit,
+            supported_devices=self.__supported_devices__,
+            dtype=dtype_to_str(dtype),
+        )
+
+        metadata = merge_metadata(self.metadata, metadata)
+
+        return AtomisticModel(self.eval(), metadata, capabilities)
+
+    def _add_output(self, target_name: str, target_info: TargetInfo) -> None:
+        self.outputs[target_name] = ModelOutput(
+            quantity=target_info.quantity,
+            unit=target_info.unit,
+            per_atom=True,
+        )
+
+        self.key_labels[target_name] = target_info.layout.keys
+        self.component_labels[target_name] = [
+            block.components for block in target_info.layout.blocks()
+        ]
+        self.property_labels[target_name] = [
+            block.properties for block in target_info.layout.blocks()
+        ]
+
+    def requested_neighbor_lists(
+        self,
+    ) -> List[NeighborListOptions]:
+        return [
+            NeighborListOptions(
+                cutoff=self.cutoff_radius,
+                full_list=True,
+                strict=True,
+            )
+        ]
+
+    def get_checkpoint(self) -> Dict:
+        checkpoint = {
+            "architecture_name": "experimental.phace",
+            "model_ckpt_version": self.__checkpoint_version__,
+            "metadata": self.metadata,
+            "model_data": {
+                "model_hypers": self.hypers,
+                "dataset_info": self.dataset_info,
+            },
+            "epoch": None,
+            "best_epoch": None,
+            "model_state_dict": self.state_dict(),
+            "best_model_state_dict": self.state_dict(),
+        }
+        return checkpoint
+
+    @classmethod
+    def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict:
+        for v in range(1, cls.__checkpoint_version__):
+            if checkpoint["model_ckpt_version"] == v:
+                update = getattr(checkpoints, f"model_update_v{v}_v{v + 1}")
+                update(checkpoint)
+                checkpoint["model_ckpt_version"] = v + 1
+
+        if checkpoint["model_ckpt_version"] != cls.__checkpoint_version__:
+            raise RuntimeError(
+                f"Unable to upgrade the checkpoint: the checkpoint is using model "
+                f"version {checkpoint['model_ckpt_version']}, while the current model "
+                f"version is {cls.__checkpoint_version__}."
+            )
+
+        return checkpoint
diff --git a/src/metatrain/experimental/phace/modules/__init__.py b/src/metatrain/experimental/phace/modules/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/metatrain/experimental/phace/modules/base_model.py b/src/metatrain/experimental/phace/modules/base_model.py
new file mode 100644
index 0000000000..ba9f16a7a7
--- /dev/null
+++ b/src/metatrain/experimental/phace/modules/base_model.py
@@ -0,0 +1,430 @@
+from typing import Dict, List
+
+import numpy as np
+import torch
+from torch.func import functional_call, grad
+
+from .center_embedding import embed_centers
+from .cg import get_cg_coefficients
+from .cg_iterator import CGIterator
+from .layers import Linear
+from .message_passing import EquivariantMessagePasser, InvariantMessagePasser
+from .precomputations import Precomputer
+from .tensor_product import couple_features_all, uncouple_features_all
+
+
+class BaseModel(torch.nn.Module):
+    def __init__(self, hypers, dataset_info) -> None:
+        super().__init__()
+        self.atomic_types = dataset_info.atomic_types
+        self.hypers = hypers
+
+        self.nu_max = hypers["max_correlation_order_per_layer"]
+        self.head_num_layers = hypers["head_num_layers"]
+        self.spherical_linear_layers = hypers["spherical_linear_layers"]
+        self.register_buffer("nu_scaling", torch.tensor(hypers["nu_scaling"]))
+
+        # A module that precomputes quantities that are useful in all message-passing
+        # steps (spherical harmonics, distances)
+        self.precomputer = Precomputer(
+            max_eigenvalue=hypers["radial_basis"]["max_eigenvalue"],
+            cutoff=hypers["cutoff"],
+            cutoff_width=hypers["cutoff_width"],
+            scale=hypers["radial_basis"]["scale"],
+            optimizable_lengthscales=hypers["radial_basis"]["optimizable_lengthscales"],
+            all_species=self.atomic_types,
+            use_sphericart=hypers["use_sphericart"],
+        )
+
+        # representation sizes
+        n_max = self.precomputer.n_max_l
+        self.l_max = len(n_max) - 1
+        n_channels = hypers["num_element_channels"]
+        if hypers["force_rectangular"]:
+            self.k_max_l = [n_channels * n_max[0]] * (self.l_max + 1)
+        else:
+            self.k_max_l = [
+                n_channels * n_max[l]
+                for l in range(self.l_max + 1)  # noqa: E741
+            ]
+
+        ################
+        # Transformation matrices from the "coupled" (aka spherical) basis to the
+        # uncoupled basis (used for tensor products) and back
+        cg_calculator = get_cg_coefficients(2 * ((self.l_max + 1) // 2))
+        self.padded_l_list = [2 * ((l + 1) // 2) for l in range(self.l_max + 1)]  # noqa: E741
+        U_dict = {}
+        for padded_l in np.unique(self.padded_l_list):
+            cg_tensors = [
+                cg_calculator._cgs[(padded_l // 2, padded_l // 2, L)]
+                for L in range(padded_l + 1)
+            ]
+            U = torch.concatenate(
+                [cg_tensor for cg_tensor in cg_tensors], dim=2
+            ).reshape((padded_l + 1) ** 2, (padded_l + 1) ** 2)
+            assert torch.allclose(
+                U @ U.T, torch.eye((padded_l + 1) ** 2, dtype=U.dtype)
+            )
+            assert torch.allclose(
+                U.T @ U, torch.eye((padded_l + 1) ** 2, dtype=U.dtype)
+            )
+            U_dict[int(padded_l)] = U
+        self.U_dict = U_dict
+        ################
+
+        self.num_message_passing_layers = hypers["num_message_passing_layers"]
+        if self.num_message_passing_layers < 1:
+            raise ValueError("Number of message-passing layers must be at least 1")
+
+        # A buffer that maps atomic types to indices in the embeddings
+        species_to_species_index = torch.zeros(
+            (max(self.atomic_types) + 1,), dtype=torch.int
+        )
+        species_to_species_index[self.atomic_types] = torch.arange(
+            len(self.atomic_types), dtype=torch.int
+        )
+        self.register_buffer("species_to_species_index", species_to_species_index)
+
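+        # Learnable per-element embeddings of size ``num_element_channels``;
+        # ``embed_centers`` multiplies them channel-wise into the features of
+        # each central atom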
self.embeddings = torch.nn.Embedding(len(self.atomic_types), n_channels) + + # The message passing is invariant for the first layer + self.invariant_message_passer = InvariantMessagePasser( + self.atomic_types, + hypers["mp_scaling"], + hypers["disable_nu_0"], + self.precomputer.n_max_l, + self.k_max_l, + ) + # First CG iterator + self.cg_iterator = CGIterator( + self.k_max_l, self.nu_max - 1, self.spherical_linear_layers + ) + + # Subsequent message-passing layers + equivariant_message_passers: List[EquivariantMessagePasser] = [] + generalized_cg_iterators: List[CGIterator] = [] + for _ in range(self.num_message_passing_layers - 1): + equivariant_message_passer = EquivariantMessagePasser( + self.precomputer.n_max_l, + self.k_max_l, + hypers["mp_scaling"], + self.spherical_linear_layers, + ) + equivariant_message_passers.append(equivariant_message_passer) + generalized_cg_iterator = CGIterator( + self.k_max_l, self.nu_max - 1, self.spherical_linear_layers + ) + generalized_cg_iterators.append(generalized_cg_iterator) + self.equivariant_message_passers = torch.nn.ModuleList( + equivariant_message_passers + ) + self.generalized_cg_iterators = torch.nn.ModuleList(generalized_cg_iterators) + + # Heads and last layers + self.head_types = self.hypers["heads"] + self.heads = torch.nn.ModuleDict() + self.last_layers = torch.nn.ModuleDict() + for target_name, target_info in dataset_info.targets.items(): + self._add_output(target_name, target_info) + + def forward( + self, batch: Dict[str, torch.Tensor] + ) -> Dict[str, Dict[int, torch.Tensor]]: + """ + Forward pass of the base model. + + :param batch: Dictionary containing batched tensors: + - positions: stacked positions of all atoms [N_total, 3] + - cells: stacked unit cells [N_structures, 3, 3] + - species: atomic types of all atoms [N_total] + - cell_shifts: cell shift vectors for all pairs [N_pairs, 3] + - center_indices: global center indices for all pairs [N_pairs] + - neighbor_indices: global neighbor indices for all pairs [N_pairs] + - structure_pairs: structure index for each pair [N_pairs] + :return: Dictionary of predictions + """ + device = batch["positions"].device + if self.U_dict[0].device != device: + self.U_dict = {key: U.to(device) for key, U in self.U_dict.items()} + dtype = batch["positions"].dtype + if self.U_dict[0].dtype != dtype: + self.U_dict = {key: U.to(dtype) for key, U in self.U_dict.items()} + + n_atoms = batch["positions"].size(0) + + # precomputation of distances and spherical harmonics + spherical_harmonics, radial_basis = self.precomputer( + positions=batch["positions"], + cells=batch["cells"], + cell_shifts=batch["cell_shifts"], + center_indices=batch["center_indices"], + neighbor_indices=batch["neighbor_indices"], + structure_pairs=batch["structure_pairs"], + center_species=batch["species"][batch["center_indices"]], + neighbor_species=batch["species"][batch["neighbor_indices"]], + ) + + # scaling the spherical harmonics in this way makes sure that each successive + # body-order is scaled by the same factor + spherical_harmonics = [sh * self.nu_scaling for sh in spherical_harmonics] + + # calculate the center embeddings; these are shared across all layers for now + center_species_indices = self.species_to_species_index[batch["species"]] + center_embeddings = self.embeddings(center_species_indices) + + initial_features = torch.ones( + (n_atoms, 1, self.k_max_l[0]), + dtype=batch["positions"].dtype, + device=batch["positions"].device, + ) + initial_element_embedding = embed_centers( + [initial_features], 
center_embeddings + )[0] + # (now they are all the same as the center embeddings) + + # ACE-like features + features = self.invariant_message_passer( + radial_basis, + spherical_harmonics, + batch["center_indices"], + batch["neighbor_indices"], + n_atoms, + initial_element_embedding, + ) + features = uncouple_features_all( # from spherical to TP basis + features, + self.k_max_l, + self.U_dict, + self.l_max, + self.padded_l_list, + ) + features = self.cg_iterator(features, self.U_dict) + + # message passing + for message_passer, generalized_cg_iterator in zip( + self.equivariant_message_passers, + self.generalized_cg_iterators, + strict=False, + ): + embedded_features = embed_centers(features, center_embeddings) + mp_features = message_passer( + radial_basis, + spherical_harmonics, + batch["center_indices"], + batch["neighbor_indices"], + embedded_features, + self.U_dict, + ) + iterated_features = generalized_cg_iterator(mp_features, self.U_dict) + features = iterated_features + + features = couple_features_all( # back to spherical basis + features, + self.U_dict, + self.l_max, + self.padded_l_list, + ) + + # center embedding + features = embed_centers(features, center_embeddings) + + # predictions + return_dict: Dict[str, Dict[int, torch.Tensor]] = {} + return_dict["features"] = {l: tensor for l, tensor in enumerate(features)} # noqa: E741 + + last_layer_feature_dict: Dict[str, List[torch.Tensor]] = {} + for output_name, layer in self.heads.items(): + last_layer_features = features + last_layer_features[0] = layer(last_layer_features[0]) # only L=0 + last_layer_feature_dict[output_name] = last_layer_features + + for output_name, layer in self.last_layers.items(): + output: Dict[int, torch.Tensor] = {} + for l_str, layer_L in layer.items(): + l = int(l_str) # noqa: E741 + output[l] = layer_L(last_layer_feature_dict[output_name][l]) + return_dict[output_name] = output + + for output_name, llf in last_layer_feature_dict.items(): + return_dict[f"{output_name}__llf"] = {l: t for l, t in enumerate(llf)} # noqa: E741 + + return return_dict + + def _add_output(self, target_name, target_info): + if target_name not in self.head_types: + if target_info.is_scalar: + use_mlp = True # default to MLP for scalars + else: + use_mlp = False # can't use MLP for equivariants + # TODO: the equivariant, or part of it, could be a scalar... 
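+                # (e.g., the o3_lambda=0 blocks of a spherical target are
+                # invariant scalars)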
+        else:
+            # specified by the user
+            use_mlp = self.head_types[target_name] == "mlp"
+
+        if use_mlp:
+            if target_info.is_spherical or target_info.is_cartesian:
+                raise ValueError("MLP heads are only supported for scalar targets.")
+
+            layers = (
+                [Linear(self.k_max_l[0], self.k_max_l[0]), torch.nn.SiLU()]
+                if self.head_num_layers == 1
+                else [Linear(self.k_max_l[0], 4 * self.k_max_l[0]), torch.nn.SiLU()]
+                + [Linear(4 * self.k_max_l[0], 4 * self.k_max_l[0]), torch.nn.SiLU()]
+                * (self.head_num_layers - 2)
+                + [Linear(4 * self.k_max_l[0], self.k_max_l[0]), torch.nn.SiLU()]
+            )
+            self.heads[target_name] = torch.nn.Sequential(*layers)
+        else:
+            self.heads[target_name] = torch.nn.Identity()
+
+        if target_info.is_scalar:
+            self.last_layers[target_name] = torch.nn.ModuleDict(
+                {
+                    "0": Linear(
+                        self.k_max_l[0], len(target_info.layout.block().properties)
+                    )
+                }
+            )
+        elif target_info.is_cartesian:
+            # here, we handle Cartesian targets
+            # we just treat them as spherical L=1 targets; the conversion will be
+            # performed in the metatensor wrapper
+            if len(target_info.layout.block().components) == 1:
+                self.last_layers[target_name] = torch.nn.ModuleDict(
+                    {
+                        "1": Linear(
+                            self.k_max_l[1], len(target_info.layout.block().properties)
+                        )
+                    }
+                )
+            else:
+                raise NotImplementedError(
+                    "PhACE only supports Cartesian targets with rank=1."
+                )
+        else:  # spherical equivariant
+            irreps = []
+            for key in target_info.layout.keys:
+                key_values = key.values
+                l = int(key_values[0])  # noqa: E741
+                # s = int(key_values[1]) is ignored here
+                irreps.append(l)
+                # provide a good error if the basis is not big enough
+                if l > self.l_max:
+                    raise ValueError(
+                        f"Target {target_name} requires l={l}, but the model's basis "
+                        f"only goes up to l={self.l_max}. You should increase the "
+                        "``max_eigenvalue`` hyperparameter."
+                    )
+            self.last_layers[target_name] = torch.nn.ModuleDict(
+                {
+                    str(l): Linear(
+                        self.k_max_l[l],
+                        len(target_info.layout.block({"o3_lambda": l}).properties),
+                    )
+                    for l in irreps  # noqa: E741
+                }
+            )
+
+
+class GradientModel(torch.nn.Module):
+    """
+    Wrapper around BaseModel that computes gradients with respect to positions and
+    strain.
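+
+    Energies are evaluated through ``torch.func.functional_call`` and
+    differentiated with ``torch.func.grad``; strain derivatives are obtained by
+    deforming positions and cells with a per-structure strain matrix
+    (initialized to the identity), which yields the virial/stress.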
+ """ + + def __init__(self, module) -> None: + super().__init__() + self.module = module + + def forward( + self, + batch: Dict[str, torch.Tensor], + outputs_to_take_gradients_of: List[str], + ): + if len(outputs_to_take_gradients_of) == 0: + return self.module(batch) + + n_structures = batch["n_atoms"].size(0) + device = batch["positions"].device + dtype = batch["positions"].dtype + + def compute_energy(params, buffers, positions, strains, output_name): + # Apply strain to positions and cells + # For each atom, get the strain matrix for its structure by indexing with + # structure_centers + # strains: [n_structures, 3, 3] + # structure_centers: [n_atoms] - maps each atom to its structure index + # positions: [n_atoms, 3] + + # Get the strain matrix for each atom: [n_atoms, 3, 3] + atom_strains = strains[batch["structure_centers"]] + + # Apply strain to positions: pos @ strain for each atom (using einsum) + strained_positions = torch.einsum("ij,ijk->ik", positions, atom_strains) + + # Apply strain to cells: [n_structures, 3, 3] @ [n_structures, 3, 3] + strained_cells = torch.bmm(batch["cells"], strains) + + # Create a modified batch with strained positions and cells + strained_batch = { + "positions": strained_positions, + "cells": strained_cells, + "species": batch["species"], + "cell_shifts": batch["cell_shifts"], + "center_indices": batch["center_indices"], + "neighbor_indices": batch["neighbor_indices"], + "structure_pairs": batch["structure_pairs"], + } + + predictions = functional_call( + self.module, (params, buffers), (strained_batch,) + ) + return predictions[output_name][0].sum(), predictions + + compute_val_and_grad = grad(compute_energy, argnums=(2, 3), has_aux=True) + + params = dict(self.module.named_parameters()) + buffers = dict(self.module.named_buffers()) + + # Create strain tensors (one 3x3 identity per structure) + strains = ( + torch.eye(3, device=device, dtype=dtype) + .unsqueeze(0) + .expand(n_structures, -1, -1) + .clone() + ) # [n_structures, 3, 3] + + all_gradients = {} + for output_name in outputs_to_take_gradients_of: + (pos_grad, strain_grads), predictions = compute_val_and_grad( + params, buffers, batch["positions"], strains, output_name + ) + all_gradients[f"{output_name}__pos"] = { + -1: pos_grad # Forces are negative gradient of energy + } + all_gradients[f"{output_name}__str"] = { + -1: strain_grads # Virial/stress from strain gradient + } + + predictions.update(all_gradients) + return predictions + + +class FakeGradientModel(torch.nn.Module): + """ + Wrapper around BaseModel that does not compute gradients. + + Used during inference when returning gradients from inside the model is not needed + and torchscript compatibility is required. 
+ """ + + def __init__(self, module) -> None: + super().__init__() + self.module = module + + def forward( + self, + batch: Dict[str, torch.Tensor], + outputs_to_take_gradients_of: List[str], + ): + return self.module(batch) diff --git a/src/metatrain/experimental/phace/modules/center_embedding.py b/src/metatrain/experimental/phace/modules/center_embedding.py new file mode 100644 index 0000000000..354b8cccb2 --- /dev/null +++ b/src/metatrain/experimental/phace/modules/center_embedding.py @@ -0,0 +1,19 @@ +from typing import List + +import torch + + +def embed_centers(features: List[torch.Tensor], center_embeddings: torch.Tensor): + # multiplies arbitrary equivariant features by the provided center embeddings + # (the center embeddings are expanded as needed) + n_channels = center_embeddings.shape[-1] + new_features: List[torch.Tensor] = [] + for feature_tensor in features: + assert feature_tensor.shape[-1] % n_channels == 0 + n_repeats = feature_tensor.shape[-1] // n_channels + center_embeddings_broadcast = center_embeddings.repeat(1, n_repeats) + for _ in range(len(feature_tensor.shape) - len(center_embeddings.shape)): + center_embeddings_broadcast = center_embeddings_broadcast.unsqueeze(1) + new_block_values = feature_tensor * center_embeddings_broadcast + new_features.append((new_block_values)) + return new_features diff --git a/src/metatrain/experimental/phace/modules/cg.py b/src/metatrain/experimental/phace/modules/cg.py new file mode 100644 index 0000000000..0924772d9a --- /dev/null +++ b/src/metatrain/experimental/phace/modules/cg.py @@ -0,0 +1,108 @@ +import numpy as np +import torch +import wigners + + +def get_cg_coefficients(l_max): + cg_object = ClebschGordanReal() + for l1 in range(l_max + 1): + for l2 in range(l_max + 1): + for L in range(abs(l1 - l2), min(l1 + l2, l_max) + 1): + cg_object._add(l1, l2, L) + return cg_object + + +class ClebschGordanReal: + def __init__(self): + self._cgs = {} + + def _add(self, l1, l2, L): + # print(f"Adding new CGs with l1={l1}, l2={l2}, L={L}") + + if self._cgs is None: + raise ValueError("Trying to add CGs when not initialized... exiting") + + if (l1, l2, L) in self._cgs: + raise ValueError("Trying to add CGs that are already present... 
exiting") + + maxx = max(l1, max(l2, L)) + + # real-to-complex and complex-to-real transformations as matrices + r2c = {} + c2r = {} + for l in range(0, maxx + 1): # noqa: E741 + r2c[l] = _real2complex(l) + c2r[l] = np.conjugate(r2c[l]).T + + complex_cg = _complex_clebsch_gordan_matrix(l1, l2, L) + + real_cg = (r2c[l1].T @ complex_cg.reshape(2 * l1 + 1, -1)).reshape( + complex_cg.shape + ) + + real_cg = real_cg.swapaxes(0, 1) + real_cg = (r2c[l2].T @ real_cg.reshape(2 * l2 + 1, -1)).reshape(real_cg.shape) + real_cg = real_cg.swapaxes(0, 1) + + real_cg = real_cg @ c2r[L].T + + if (l1 + l2 + L) % 2 == 0: + rcg = np.real(real_cg) + else: + rcg = np.imag(real_cg) + + # Zero any possible (and very rare) near-zero elements + where_almost_zero = np.where( + np.logical_and(np.abs(rcg) > 0, np.abs(rcg) < 1e-14) + ) + if len(where_almost_zero[0] != 0): + print("INFO: Found almost-zero CG!") + for i0, i1, i2 in zip( + where_almost_zero[0], + where_almost_zero[1], + where_almost_zero[2], + strict=False, + ): + rcg[i0, i1, i2] = 0.0 + + self._cgs[(l1, l2, L)] = torch.tensor(rcg) + + def get(self, key): + if key in self._cgs: + return self._cgs[key] + else: + self._add(key[0], key[1], key[2]) + return self._cgs[key] + + +def _real2complex(L): + """ + Computes a matrix that can be used to convert from real to complex-valued + spherical harmonics(coefficients) of order L. + + It's meant to be applied to the left, ``real2complex @ [-L..L]``. + """ + result = np.zeros((2 * L + 1, 2 * L + 1), dtype=np.complex128) + + I_SQRT_2 = 1.0 / np.sqrt(2) + + for m in range(-L, L + 1): + if m < 0: + result[L - m, L + m] = I_SQRT_2 * 1j * (-1) ** m + result[L + m, L + m] = -I_SQRT_2 * 1j + + if m == 0: + result[L, L] = 1.0 + + if m > 0: + result[L + m, L + m] = I_SQRT_2 * (-1) ** m + result[L - m, L + m] = I_SQRT_2 + + return result + + +def _complex_clebsch_gordan_matrix(l1, l2, L): + if np.abs(l1 - l2) > L or np.abs(l1 + l2) < L: + return np.zeros((2 * l1 + 1, 2 * l2 + 1, 2 * L + 1), dtype=np.double) + else: + return wigners.clebsch_gordan_array(l1, l2, L) diff --git a/src/metatrain/experimental/phace/modules/cg_iterator.py b/src/metatrain/experimental/phace/modules/cg_iterator.py new file mode 100644 index 0000000000..3fd3f1cca6 --- /dev/null +++ b/src/metatrain/experimental/phace/modules/cg_iterator.py @@ -0,0 +1,65 @@ +from typing import Dict, List + +import torch + +from .layers import LinearList as Linear +from .tensor_product import tensor_product + + +class CGIterator(torch.nn.Module): + # A high-level CG iterator, doing multiple iterations + def __init__(self, k_max_l, number_of_iterations, spherical_linear_layers): + super().__init__() + self.number_of_iterations = number_of_iterations + + # equivariant linear mixers (to be used at the beginning) + mixers = [] + for _ in range(self.number_of_iterations + 1): + mixers.append(Linear(k_max_l, spherical_linear_layers)) + self.mixers = torch.nn.ModuleList(mixers) + + # CG iterations + cg_iterations = [] + for _ in range(self.number_of_iterations): + cg_iterations.append(CGIteration(k_max_l, spherical_linear_layers)) + self.cg_iterations = torch.nn.ModuleList(cg_iterations) + + def forward( + self, features: List[torch.Tensor], U_dict: Dict[int, torch.Tensor] + ) -> List[torch.Tensor]: + density = features + mixed_densities = [mixer(density, U_dict) for mixer in self.mixers] + + starting_density = mixed_densities[0] + density_index = 1 + current_density = starting_density + for iterator in self.cg_iterations: + current_density = iterator( + current_density, 
+                current_density, mixed_densities[density_index], U_dict
+            )
+            density_index += 1
+
+        return current_density
+
+
+class CGIteration(torch.nn.Module):
+    # A single Clebsch-Gordan iteration, including:
+    # - tensor product
+    # - linear transformation
+    # - skip connection
+    def __init__(self, k_max_l, spherical_linear_layers):
+        super().__init__()
+        self.linear = Linear(k_max_l, spherical_linear_layers)
+
+    def forward(
+        self,
+        features_1: List[torch.Tensor],
+        features_2: List[torch.Tensor],
+        U_dict: Dict[int, torch.Tensor],
+    ) -> List[torch.Tensor]:
+        features_out = tensor_product(features_1, features_2)
+        features_out = self.linear(features_out, U_dict)
+        features_out = [
+            f1 + fo for f1, fo in zip(features_1, features_out, strict=True)
+        ]
+        return features_out
diff --git a/src/metatrain/experimental/phace/modules/layers.py b/src/metatrain/experimental/phace/modules/layers.py
new file mode 100644
index 0000000000..0e377c55b7
--- /dev/null
+++ b/src/metatrain/experimental/phace/modules/layers.py
@@ -0,0 +1,64 @@
+from typing import Dict, List
+
+import torch
+
+from .tensor_product import couple_features_all, uncouple_features_all
+
+
+class Linear(torch.nn.Module):
+    # NTK-style linear layer (neural tangent kernel)
+
+    def __init__(self, n_feat_in, n_feat_out):
+        super().__init__()
+        self.linear_layer = torch.nn.Linear(n_feat_in, n_feat_out, bias=False)
+        self.linear_layer.weight.data.normal_(0.0, 1.0)
+        self.n_feat_in = n_feat_in if n_feat_in > 0 else 1
+
+    def forward(self, x):
+        return self.linear_layer(x) * self.n_feat_in ** (-0.5)
+
+
+class LinearList(torch.nn.Module):
+    # list of linear layers for equivariant features, applied either in the spherical
+    # (coupled) basis (spherical_linear_layers=True) or in the uncoupled (TP) basis
+
+    def __init__(self, k_max_l: List[int], spherical_linear_layers) -> None:
+        super().__init__()
+        self.spherical_linear_layers = spherical_linear_layers
+        self.k_max_l = k_max_l
+        self.l_max = len(k_max_l) - 1
+        self.padded_l_list = [2 * ((l + 1) // 2) for l in range(self.l_max + 1)]  # noqa: E741
+        if spherical_linear_layers:
+            self.linears = torch.nn.ModuleList(
+                [Linear(k_max, k_max) for k_max in k_max_l]
+            )
+        else:
+            l_max = len(k_max_l) - 1
+            self.linears = []
+            for l in range(l_max, -1, -1):  # noqa: E741
+                lower_bound = k_max_l[l + 1] if l < l_max else 0
+                upper_bound = k_max_l[l]
+                dimension = upper_bound - lower_bound
+                self.linears.append(Linear(dimension, dimension))
+            self.linears = torch.nn.ModuleList(self.linears[::-1])
+
+    def forward(
+        self, features_list: List[torch.Tensor], U_dict: Dict[int, torch.Tensor]
+    ) -> List[torch.Tensor]:
+        if self.spherical_linear_layers:
+            features_list = couple_features_all(
+                features_list, U_dict, self.l_max, self.padded_l_list
+            )
+
+        new_features_list: List[torch.Tensor] = []
+        for i, linear in enumerate(self.linears):
+            current_features = features_list[i]
+            new_features = linear(current_features)
+            new_features_list.append(new_features)
+
+        if self.spherical_linear_layers:
+            new_features_list = uncouple_features_all(
+                new_features_list, self.k_max_l, U_dict, self.l_max, self.padded_l_list
+            )
+
+        return new_features_list
diff --git a/src/metatrain/experimental/phace/modules/message_passing.py b/src/metatrain/experimental/phace/modules/message_passing.py
new file mode 100644
index 0000000000..62a1a63a6c
--- /dev/null
+++ b/src/metatrain/experimental/phace/modules/message_passing.py
@@ -0,0 +1,138 @@
+from typing import Dict, List
+
+import torch
+
+from .layers import LinearList as Linear
+from .radial_mlp import MLPRadialBasis
+from .tensor_product import ( + tensor_product, + uncouple_features_all, +) + + +class InvariantMessagePasser(torch.nn.Module): + # performs invariant message passing with linear contractions + def __init__( + self, all_species: List[int], mp_scaling, disable_nu_0, n_max_l, k_max_l + ) -> None: + super().__init__() + + self.all_species = all_species + self.radial_basis_mlp = MLPRadialBasis(n_max_l, k_max_l) + self.n_max_l = n_max_l + self.k_max_l = k_max_l + self.l_max = len(self.n_max_l) - 1 + self.irreps_out = [(l, 1) for l in range(self.l_max + 1)] # noqa: E741 + + # Register mp_scaling as a buffer for efficiency + self.register_buffer("mp_scaling", torch.tensor(mp_scaling)) + self.disable_nu_0 = disable_nu_0 + + def forward( + self, + radial_basis: List[torch.Tensor], + spherical_harmonics: List[torch.Tensor], + centers, + neighbors, + n_atoms: int, + initial_center_embedding, + ) -> List[torch.Tensor]: + radial_basis = self.radial_basis_mlp(radial_basis) + + density = [] + for l in range(self.l_max + 1): # noqa: E741 + spherical_harmonics_l = spherical_harmonics[l] + radial_basis_l = radial_basis[l] + density_l = torch.zeros( + (n_atoms, spherical_harmonics_l.shape[1], radial_basis_l.shape[1]), + device=radial_basis_l.device, + dtype=radial_basis_l.dtype, + ) + density_l.index_add_( + dim=0, + index=centers, + source=spherical_harmonics_l.unsqueeze(2) + * radial_basis_l.unsqueeze(1) + * initial_center_embedding[neighbors][:, :, : radial_basis_l.shape[1]], + ) + density.append(density_l * self.mp_scaling) + + if not self.disable_nu_0: + density[0] = density[0] + initial_center_embedding + + return density + + +class EquivariantMessagePasser(torch.nn.Module): + # performs equivariant message passing with linear contractions + def __init__( + self, + n_max_l, + k_max_l, + mp_scaling, + spherical_linear_layers, + ) -> None: + super().__init__() + + self.n_max_l = list(n_max_l) + self.k_max_l = k_max_l + self.l_max = len(self.n_max_l) - 1 + + # Register mp_scaling as a buffer for efficiency + self.register_buffer("mp_scaling", torch.tensor(mp_scaling)) + self.padded_l_list = [2 * ((l + 1) // 2) for l in range(self.l_max + 1)] # noqa: E741 + + self.linear = Linear(self.k_max_l, spherical_linear_layers) + + self.radial_basis_mlp = MLPRadialBasis(n_max_l, k_max_l) + + def forward( + self, + radial_basis: List[torch.Tensor], + spherical_harmonics: List[torch.Tensor], + centers, + neighbors, + features: List[torch.Tensor], + U_dict: Dict[int, torch.Tensor], + ) -> List[torch.Tensor]: + radial_basis = self.radial_basis_mlp(radial_basis) + vector_expansion = [ + spherical_harmonics[l].unsqueeze(2) * radial_basis[l].unsqueeze(1) + for l in range(self.l_max + 1) # noqa: E741 + ] + + uncoupled_vector_expansion = uncouple_features_all( + vector_expansion, self.k_max_l, U_dict, self.l_max, self.padded_l_list + ) + + n_atoms = features[0].shape[0] + + indexed_features = [] + for feature in features: + indexed_features.append(feature[neighbors]) + + combined_features = tensor_product(uncoupled_vector_expansion, indexed_features) + + combined_features_pooled = [] + for f in combined_features: + combined_features_pooled.append( + torch.zeros( + (n_atoms,) + f.shape[1:], + device=f.device, + dtype=f.dtype, + ), + ) + combined_features_pooled[-1].index_add_( + dim=0, + index=centers, + source=f, + ) + + # apply mp_scaling + combined_features_pooled = [ + (f * self.mp_scaling) for f in combined_features_pooled + ] + + features_out = self.linear(combined_features_pooled, U_dict) + features_out = [f + fo for 
f, fo in zip(features, features_out, strict=False)] + return features_out diff --git a/src/metatrain/experimental/phace/modules/physical_basis.py b/src/metatrain/experimental/phace/modules/physical_basis.py new file mode 100644 index 0000000000..4171367377 --- /dev/null +++ b/src/metatrain/experimental/phace/modules/physical_basis.py @@ -0,0 +1,74 @@ +import copy + +import numpy as np +from physical_basis import PhysicalBasis + +from .splines import generate_splines + + +# This splines the basis functions from the ``physical_basis`` package. +# It also normalizes them. + + +def get_physical_basis_spliner(E_max, r_cut, normalize): + l_max = 50 + n_max = 50 + a = 10.0 # by construction of the files + + physical_basis = PhysicalBasis() + E_ln = physical_basis.E_ln + E_nl = E_ln.T + l_max_new = np.where(E_nl[0, :] <= E_max)[0][-1] + if l_max_new > l_max: + raise ValueError("l_max too large, try decreasing E_max") + else: + l_max = l_max_new + + n_max_l = [] + for l in range(l_max + 1): # noqa: E741 + n_max_l.append(np.where(E_nl[:, l] <= E_max)[0][-1] + 1) + if n_max_l[0] > n_max: + raise ValueError("n_max too large, try decreasing max_eigenvalue") + + def function_for_splining(n, l, x): # noqa: E741 + ret = physical_basis.compute(n, l, x) + if normalize: + # normalize by square root of sphere volume, excluding sqrt(4pi) which is + # included in the SH + ret *= np.sqrt((1 / 3) * r_cut**3) + return ret + + def function_for_splining_derivative(n, l, x): # noqa: E741 + ret = physical_basis.compute_derivative(n, l, x) + if normalize: + # normalize by square root of sphere volume, excluding sqrt(4pi) which is + # included in the SH + ret *= np.sqrt((1 / 3) * r_cut**3) + return ret + + def index_to_nl(index, n_max_l): + # FIXME: should probably use cumsum + n = copy.deepcopy(index) + for l in range(l_max + 1): # noqa: E741 + n -= n_max_l[l] + if n < 0: + break + return n + n_max_l[l], l + + def function_for_splining_index(index, r): + n, l = index_to_nl(index, n_max_l) # noqa: E741 + return function_for_splining(n, l, r) + + def function_for_splining_index_derivative(index, r): + n, l = index_to_nl(index, n_max_l) # noqa: E741 + return function_for_splining_derivative(n, l, r) + + spliner = generate_splines( + function_for_splining_index, + function_for_splining_index_derivative, + np.sum(n_max_l), + a, + ) + + n_max_l = [int(n_max) for n_max in n_max_l] + return n_max_l, spliner diff --git a/src/metatrain/experimental/phace/modules/precomputations.py b/src/metatrain/experimental/phace/modules/precomputations.py new file mode 100644 index 0000000000..ea74416a33 --- /dev/null +++ b/src/metatrain/experimental/phace/modules/precomputations.py @@ -0,0 +1,218 @@ +try: + import sphericart.torch +except ImportError: + pass +from math import factorial + +import numpy as np +import torch +from ase.data import covalent_radii + +from .physical_basis import get_physical_basis_spliner + + +class SphericalHarmonicsNoSphericart(torch.nn.Module): + # uses the sphericart algorithm in pytorch + def __init__(self, l_max): + super(SphericalHarmonicsNoSphericart, self).__init__() + self.l_max = l_max + + self.register_buffer( + "F", torch.empty(((self.l_max + 1) * (self.l_max + 2) // 2,)) + ) + for l in range(l_max + 1): # noqa: E741 + for m in range(0, l + 1): + self.F[l * (l + 1) // 2 + m] = (-1) ** m * np.sqrt( + (2 * l + 1) / (2 * np.pi) * factorial(l - m) / factorial(l + m) + ) + + def forward(self, xyz): + device = xyz.device + dtype = xyz.dtype + + rsq = torch.sum(xyz**2, dim=1) + xyz = xyz / 
torch.sqrt(rsq).unsqueeze(1) + + x, y, z = xyz[:, 0], xyz[:, 1], xyz[:, 2] + Q = torch.empty( + (xyz.shape[0], (self.l_max + 1) * (self.l_max + 2) // 2), + device=device, + dtype=dtype, + ) + Q[:, 0] = 1.0 + for l in range(1, self.l_max + 1): # noqa: E741 + Q[:, (l + 1) * (l + 2) // 2 - 1] = ( + -(2 * l - 1) * Q[:, l * (l + 1) // 2 - 1].clone() + ) + Q[:, (l + 1) * (l + 2) // 2 - 2] = ( + -z * Q[:, (l + 1) * (l + 2) // 2 - 1].clone() + ) + for m in range(0, l - 1): + Q[:, l * (l + 1) // 2 + m] = ( + (2 * l - 1) * z * Q[:, (l - 1) * l // 2 + m].clone() + - (l + m - 1) * Q[:, (l - 2) * (l - 1) // 2 + m].clone() + ) / (l - m) + + s = torch.empty((xyz.shape[0], self.l_max + 1), device=device, dtype=dtype) + c = torch.empty((xyz.shape[0], self.l_max + 1), device=device, dtype=dtype) + + s[:, 0] = 0.0 + c[:, 0] = 1.0 + for m in range(1, self.l_max + 1): + s[:, m] = x * s[:, m - 1].clone() + y * c[:, m - 1].clone() + c[:, m] = x * c[:, m - 1].clone() - y * s[:, m - 1].clone() + + Y = torch.empty( + (xyz.shape[0], (self.l_max + 1) * (self.l_max + 1)), + device=device, + dtype=dtype, + ) + for l in range(self.l_max + 1): # noqa: E741 + for m in range(-l, 0): + Y[:, l * l + l + m] = ( + self.F[l * (l + 1) // 2 - m] * Q[:, l * (l + 1) // 2 - m] * s[:, -m] + ) + Y[:, l * l + l] = ( + self.F[l * (l + 1) // 2] + * Q[:, l * (l + 1) // 2] + / torch.sqrt(torch.tensor(2.0, device=device, dtype=dtype)) + ) + for m in range(1, l + 1): + Y[:, l * l + l + m] = ( + self.F[l * (l + 1) // 2 + m] * Q[:, l * (l + 1) // 2 + m] * c[:, m] + ) + + return Y + + +class SphericalHarmonicsSphericart(torch.nn.Module): + def __init__(self, l_max): + super(SphericalHarmonicsSphericart, self).__init__() + self.spherical_harmonics_calculator = sphericart.torch.SphericalHarmonics( + l_max, normalized=True + ) + + def forward(self, xyz): + return self.spherical_harmonics_calculator.compute(xyz) + + +class Precomputer(torch.nn.Module): + def __init__( + self, + max_eigenvalue, + cutoff, + cutoff_width, + scale, + optimizable_lengthscales, + all_species, + use_sphericart, + ): + super().__init__() + + self.n_max_l, self.spliner = get_physical_basis_spliner( + max_eigenvalue, cutoff, normalize=True + ) + self.l_max = len(self.n_max_l) - 1 + + self.spherical_harmonics_split_list = [ + (2 * l + 1) + for l in range(self.l_max + 1) # noqa: E741 + ] + if use_sphericart: + self.spherical_harmonics = SphericalHarmonicsSphericart(self.l_max) + else: + self.spherical_harmonics = SphericalHarmonicsNoSphericart(self.l_max) + + lengthscales = torch.zeros((max(all_species) + 1)) + for species in all_species: + lengthscales[species] = np.log(scale * covalent_radii[species]) + + if optimizable_lengthscales: + self.lengthscales = torch.nn.Parameter(lengthscales) + else: + self.register_buffer("lengthscales", lengthscales) + + self.r_cut = float(cutoff) + self.cutoff_width = float(cutoff_width) + + def forward( + self, + positions, + cells, + cell_shifts, + center_indices, + neighbor_indices, + structure_pairs, + center_species, + neighbor_species, + ): + cartesian_vectors = get_cartesian_vectors( + positions, + cells, + cell_shifts, + center_indices, + neighbor_indices, + structure_pairs, + ) + + r = torch.sqrt((cartesian_vectors**2).sum(dim=-1)) + + spherical_harmonics = self.spherical_harmonics( + cartesian_vectors + ) # Get the spherical harmonics + spherical_harmonics = spherical_harmonics * (4.0 * torch.pi) ** ( + 0.5 + ) # normalize them + spherical_harmonics = torch.split( + spherical_harmonics, self.spherical_harmonics_split_list, dim=1 + ) 
# Split them into l chunks
+
+        x = r / (
+            0.1
+            + torch.exp(self.lengthscales[center_species])
+            + torch.exp(self.lengthscales[neighbor_species])
+        )
+
+        # the spline is only tabulated up to x = 10; out-of-range entries are
+        # evaluated at a dummy in-range point (5.0) and then masked to zero
+        capped_x = torch.where(x < 10.0, x, 5.0)
+        radial_functions = torch.where(
+            x.unsqueeze(1) < 10.0, self.spliner.compute(capped_x), 0.0
+        )
+
+        cutoff_multiplier = cutoff_fn(r, self.r_cut, self.cutoff_width)
+        radial_functions = radial_functions * cutoff_multiplier.unsqueeze(1)
+
+        radial_basis = torch.split(radial_functions, self.n_max_l, dim=1)
+
+        return spherical_harmonics, radial_basis
+
+
+def get_cartesian_vectors(
+    positions, cells, cell_shifts, center_indices, neighbor_indices, structure_pairs
+):
+    """
+    Calculate direction vectors between center and neighbor atoms.
+
+    :param positions: Atomic positions [N_total, 3]
+    :param cells: Unit cells [N_structures, 3, 3]
+    :param cell_shifts: Cell shift vectors [N_pairs, 3]
+    :param center_indices: Global center indices [N_pairs]
+    :param neighbor_indices: Global neighbor indices [N_pairs]
+    :param structure_pairs: Structure index for each pair [N_pairs]
+    :return: Direction vectors from center to neighbor [N_pairs, 3]
+    """
+    direction_vectors = (
+        positions[neighbor_indices]
+        - positions[center_indices]
+        + torch.einsum(
+            "ab, abc -> ac", cell_shifts.to(cells.dtype), cells[structure_pairs]
+        )
+    )
+    return direction_vectors
+
+
+def cutoff_fn(r, r_cut: float, cutoff_width: float):
+    # smooth cosine switching function: 1 below r_cut - cutoff_width, decaying
+    # continuously to 0 at r_cut (the 0.5 factor makes the two branches match
+    # at r = r_cut - cutoff_width)
+    return torch.where(
+        r < r_cut - cutoff_width,
+        1.0,
+        0.5
+        * (1.0 + torch.cos((r - (r_cut - cutoff_width)) * torch.pi / cutoff_width)),
+    )
diff --git a/src/metatrain/experimental/phace/modules/radial_mlp.py b/src/metatrain/experimental/phace/modules/radial_mlp.py
new file mode 100644
index 0000000000..fc7e9b8d4e
--- /dev/null
+++ b/src/metatrain/experimental/phace/modules/radial_mlp.py
@@ -0,0 +1,44 @@
+from typing import List
+
+import torch
+
+from .layers import Linear
+
+
+class MLPRadialBasis(torch.nn.Module):
+    def __init__(self, n_max_l, k_max_l) -> None:
+        super().__init__()
+
+        l_max = len(n_max_l) - 1
+        self.radial_mlps = torch.nn.ModuleDict(
+            {
+                str(l): torch.nn.Sequential(
+                    Linear(n_max_l[l], 4 * k_max_l[l]),
+                    torch.nn.SiLU(),
+                    Linear(
+                        4 * k_max_l[l],
+                        4 * k_max_l[l],
+                    ),
+                    torch.nn.SiLU(),
+                    Linear(
+                        4 * k_max_l[l],
+                        4 * k_max_l[l],
+                    ),
+                    torch.nn.SiLU(),
+                    Linear(
+                        4 * k_max_l[l],
+                        k_max_l[l],
+                    ),
+                )
+                for l in range(l_max + 1)  # noqa: E741
+            }
+        )
+
+    def forward(self, radial_basis: List[torch.Tensor]) -> List[torch.Tensor]:
+        radial_basis_after_mlp = []
+        for l_string, radial_mlp_l in self.radial_mlps.items():
+            l = int(l_string)  # noqa: E741
+            radial_basis_after_mlp.append(radial_mlp_l(radial_basis[l]))
+        radial_basis = radial_basis_after_mlp
+
+        return radial_basis
diff --git a/src/metatrain/experimental/phace/modules/splines.py b/src/metatrain/experimental/phace/modules/splines.py
new file mode 100644
index 0000000000..0ab8e270b5
--- /dev/null
+++ b/src/metatrain/experimental/phace/modules/splines.py
@@ -0,0 +1,167 @@
+import numpy as np
+import torch
+
+
+def generate_splines(
+    radial_basis,
+    radial_basis_derivatives,
+    max_index,
+    cutoff_radius,
+    requested_accuracy=1e-8,
+):
+    """Spline generator for tabulated radial integrals.
+
+    Besides some self-explanatory parameters, this function takes as inputs two
+    functions, namely radial_basis and radial_basis_derivatives. These must be
+    able to calculate the radial basis functions by taking an index and r as
+    their inputs, where the index is an integer that enumerates the flattened
+    (n, l) pairs of the basis and r is a numpy 1-D array that contains the
+    spline points at which the radial basis function (or its derivative) needs
+    to be evaluated. These functions should return a numpy 1-D array containing
+    the values of the radial basis function (or its derivative) corresponding
+    to the specified index, evaluated at all points in the r array. Spline
+    points are added until either the relative error or the absolute error
+    falls below the requested accuracy on average across all radial basis
+    functions.
+    """
+
+    def value_evaluator_2D(positions):
+        values = []
+        for index in range(max_index):
+            value = radial_basis(index, positions.numpy())
+            values.append(value)
+        values = torch.tensor(np.array(values))
+        values = values.T
+        values = values.reshape(len(positions), max_index)
+        return values
+
+    def derivative_evaluator_2D(positions):
+        derivatives = []
+        for index in range(max_index):
+            derivative = radial_basis_derivatives(index, positions.numpy())
+            derivatives.append(derivative)
+        derivatives = torch.tensor(np.array(derivatives))
+        derivatives = derivatives.T
+        derivatives = derivatives.reshape(len(positions), max_index)
+        return derivatives
+
+    dynamic_spliner = DynamicSpliner(
+        0.0,
+        cutoff_radius,
+        value_evaluator_2D,
+        derivative_evaluator_2D,
+        requested_accuracy,
+    )
+    return dynamic_spliner
+
+
+class DynamicSpliner(torch.nn.Module):
+    def __init__(
+        self, start, stop, values_fn, derivatives_fn, requested_accuracy
+    ) -> None:
+        super().__init__()
+
+        self.start = start
+        self.stop = stop
+
+        # initialize spline with 11 points; the spline calculation
+        # is performed in double precision
+        positions = torch.linspace(start, stop, 11, dtype=torch.float64)
+        self.register_buffer("spline_positions", positions)
+        self.register_buffer("spline_values", values_fn(positions))
+        self.register_buffer("spline_derivatives", derivatives_fn(positions))
+
+        self.number_of_custom_dimensions = (
+            len(self.spline_values.shape) - 1  # type: ignore
+        )
+
+        while True:
+            n_intermediate_positions = len(self.spline_positions) - 1  # type: ignore
+
+            if n_intermediate_positions >= 50000:
+                raise ValueError(
+                    "Maximum number of spline points reached.
\ + There might be a problem with the functions to be splined" + ) + + half_step = ( + self.spline_positions[1] - self.spline_positions[0] # type: ignore + ) / 2 + intermediate_positions = torch.linspace( + self.start + half_step, + self.stop - half_step, + n_intermediate_positions, + dtype=torch.float64, + ) + + estimated_values = self.compute(intermediate_positions) + new_values = values_fn(intermediate_positions) + + mean_absolute_error = torch.mean(torch.abs(estimated_values - new_values)) + mean_relative_error = torch.mean( + torch.abs((estimated_values - new_values) / new_values) + ) + + if ( + mean_absolute_error < requested_accuracy + or mean_relative_error < requested_accuracy + ): + break + + new_derivatives = derivatives_fn(intermediate_positions) + + concatenated_positions = torch.cat( + [self.spline_positions, intermediate_positions], # type: ignore + dim=0, + ) + concatenated_values = torch.cat( + [self.spline_values, new_values], # type: ignore + dim=0, + ) + concatenated_derivatives = torch.cat( + [self.spline_derivatives, new_derivatives], # type: ignore + dim=0, + ) + + sort_indices = torch.argsort(concatenated_positions, dim=0) + + self.spline_positions = concatenated_positions[sort_indices] + self.spline_values = concatenated_values[sort_indices] + self.spline_derivatives = concatenated_derivatives[sort_indices] + + self.spline_positions = self.spline_positions.to(torch.get_default_dtype()) + self.spline_values = self.spline_values.to(torch.get_default_dtype()) + self.spline_derivatives = self.spline_derivatives.to(torch.get_default_dtype()) + + def compute(self, positions): + x = positions + delta_x = self.spline_positions[1] - self.spline_positions[0] + n = (torch.floor(x / delta_x)).to(dtype=torch.long) + + t = (x - n * delta_x) / delta_x + t_2 = t**2 + t_3 = t**3 + + h00 = 2.0 * t_3 - 3.0 * t_2 + 1.0 + h10 = t_3 - 2.0 * t_2 + t + h01 = -2.0 * t_3 + 3.0 * t_2 + h11 = t_3 - t_2 + + p_k = torch.index_select(self.spline_values, dim=0, index=n) + p_k_1 = torch.index_select(self.spline_values, dim=0, index=n + 1) + + m_k = torch.index_select(self.spline_derivatives, dim=0, index=n) + m_k_1 = torch.index_select(self.spline_derivatives, dim=0, index=n + 1) + + new_shape = (-1,) + (1,) * self.number_of_custom_dimensions + h00 = h00.reshape(new_shape) + h10 = h10.reshape(new_shape) + h01 = h01.reshape(new_shape) + h11 = h11.reshape(new_shape) + + interpolated_values = ( + h00 * p_k + h10 * delta_x * m_k + h01 * p_k_1 + h11 * delta_x * m_k_1 + ) + + return interpolated_values diff --git a/src/metatrain/experimental/phace/modules/tensor_product.py b/src/metatrain/experimental/phace/modules/tensor_product.py new file mode 100644 index 0000000000..19f72a5e9d --- /dev/null +++ b/src/metatrain/experimental/phace/modules/tensor_product.py @@ -0,0 +1,153 @@ +from typing import Dict, List + +import torch + + +def split_up_features(features: List[torch.Tensor], k_max_l: List[int]): + # splits a ragged list of features into a list of lists of features + l_max = len(k_max_l) - 1 + split_features: List[List[torch.Tensor]] = [] + for l in range(l_max, -1, -1): # noqa: E741 + lower_bound = k_max_l[l + 1] if l < l_max else 0 + upper_bound = k_max_l[l] + split_features = [ + [features[lp][:, :, lower_bound:upper_bound] for lp in range(l + 1)] + ] + split_features + return split_features + + +def uncouple_features( + features: List[torch.Tensor], + U: torch.Tensor, + padded_l_max: int, +): + # spherical (coupled) to TP (uncoupled) basis + # features is a list of [..., 2*l+1, n_features] for l 
= 0, 1, ..., padded_l_max + # U is dense and [(padded_l_max+1)**2, (padded_l_max+1)**2] + if len(features) < padded_l_max + 1: + features.append( + torch.zeros( + (features[0].shape[0], 2 * padded_l_max + 1, features[0].shape[2]), + dtype=features[0].dtype, + device=features[0].device, + ) + ) + stacked_features = torch.cat(features, dim=1) + stacked_features = stacked_features.swapaxes(0, 1) + uncoupled_features = ( + U + @ stacked_features.reshape( + (padded_l_max + 1) * (padded_l_max + 1), + stacked_features.shape[1] * stacked_features.shape[-1], + ) + ).reshape( + (padded_l_max + 1) * (padded_l_max + 1), + stacked_features.shape[1], + stacked_features.shape[-1], + ) + uncoupled_features = uncoupled_features.swapaxes(0, 1) + uncoupled_features = uncoupled_features.reshape( + uncoupled_features.shape[0], + padded_l_max + 1, + padded_l_max + 1, + uncoupled_features.shape[-1], + ) + return uncoupled_features + + +def tensor_product( + uncoupled_features_1: List[torch.Tensor], + uncoupled_features_2: List[torch.Tensor], +): + # tensor product in the TP (uncoupled) basis + new_uncoupled_features = [] + for u1, u2 in zip(uncoupled_features_1, uncoupled_features_2, strict=True): + new_uncoupled_features.append(torch.einsum("...ijf,...jkf->...ikf", u1, u2)) + return new_uncoupled_features + + +def couple_features( + features: torch.Tensor, + U: torch.Tensor, + padded_l_max: int, +): + # TP (uncoupled) to spherical (coupled) basis + # features is [..., padded_l_max+1, padded_l_max+1, n_features] + # U is dense and [(padded_l_max+1)**2, (padded_l_max+1)**2] + split_sizes = [2 * l + 1 for l in range(padded_l_max + 1)] # noqa: E741 + + features = features.reshape( + features.shape[0], + (padded_l_max + 1) * (padded_l_max + 1), + features.shape[-1], + ) + features = features.swapaxes(0, 1) + features = ( + U.T + @ features.reshape( + (padded_l_max + 1) * (padded_l_max + 1), + features.shape[1] * features.shape[-1], + ) + ).reshape( + (padded_l_max + 1) * (padded_l_max + 1), + features.shape[1], + features.shape[-1], + ) + stacked_features = features.swapaxes(0, 1) + features_coupled = [ + t.contiguous() for t in torch.split(stacked_features, split_sizes, dim=1) + ] + + coupled_features = [] + for l in range(padded_l_max + 1): # noqa: E741 + coupled_features.append(features_coupled[l]) + return coupled_features + + +def uncouple_features_all( + coupled_features: List[torch.Tensor], + k_max_l: List[int], + U_dict: Dict[int, torch.Tensor], + l_max: int, + padded_l_list: List[int], +) -> List[torch.Tensor]: + # spherical (coupled) to TP (uncoupled) basis for a list of ragged spherical + # features (different number of channels per l) + split_features = split_up_features(coupled_features, k_max_l) + uncoupled_features = [] + for l in range(l_max + 1): # noqa: E741 + uncoupled_features.append( + uncouple_features( + split_features[l], + U_dict[padded_l_list[l]], + padded_l_list[l], + ) + ) + return uncoupled_features + + +def couple_features_all( + uncoupled_features: List[torch.Tensor], + U_dict: Dict[int, torch.Tensor], + l_max: int, + padded_l_list: List[int], +) -> List[torch.Tensor]: + # TP (uncoupled) to spherical (coupled) basis for a list of ragged TP features + # (different number of channels per l) + coupled_features: List[List[torch.Tensor]] = [] + for l in range(l_max + 1): # noqa: E741 + coupled_features.append( + couple_features( + uncoupled_features[l], + U_dict[padded_l_list[l]], + padded_l_list[l], + ) + ) + concat_coupled_features = [] + for l in range(l_max + 1): # noqa: E741 + 
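+        # gather the l-th spherical component from every uncoupled block with
+        # lp >= l and concatenate along the channel dimension, restoring the
+        # ragged (per-l channel count) layout produced by split_up_features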
concat_coupled_features.append( + torch.concatenate( + [coupled_features[lp][l] for lp in range(l, l_max + 1)], dim=-1 + ) + ) + return concat_coupled_features diff --git a/src/metatrain/experimental/phace/tests/__init__.py b/src/metatrain/experimental/phace/tests/__init__.py new file mode 100644 index 0000000000..b74c9a942a --- /dev/null +++ b/src/metatrain/experimental/phace/tests/__init__.py @@ -0,0 +1,12 @@ +from pathlib import Path + +from metatrain.utils.architectures import get_default_hypers + + +DATASET_PATH = str(Path(__file__).parents[5] / "tests/resources/qm9_reduced_100.xyz") +DATASET_PATH_PERIODIC = str( + Path(__file__).parents[5] / "tests/resources/carbon_reduced_100.xyz" +) + +DEFAULT_HYPERS = get_default_hypers("experimental.phace") +MODEL_HYPERS = DEFAULT_HYPERS["model"] diff --git a/src/metatrain/experimental/phace/tests/checkpoints/model-v1_trainer-v1.ckpt.gz b/src/metatrain/experimental/phace/tests/checkpoints/model-v1_trainer-v1.ckpt.gz new file mode 100644 index 0000000000..22a75ba175 Binary files /dev/null and b/src/metatrain/experimental/phace/tests/checkpoints/model-v1_trainer-v1.ckpt.gz differ diff --git a/src/metatrain/experimental/phace/tests/test_basic.py b/src/metatrain/experimental/phace/tests/test_basic.py new file mode 100644 index 0000000000..4857a70af4 --- /dev/null +++ b/src/metatrain/experimental/phace/tests/test_basic.py @@ -0,0 +1,195 @@ +import copy +from typing import Any + +import pytest +import torch +from metatomic.torch import System +from omegaconf import OmegaConf + +from metatrain.utils.abc import ModelInterface, TrainerInterface +from metatrain.utils.architectures import get_default_hypers +from metatrain.utils.data import DatasetInfo +from metatrain.utils.hypers import init_with_defaults +from metatrain.utils.loss import LossSpecification +from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists +from metatrain.utils.testing import ( + ArchitectureTests, + CheckpointTests, + OutputTests, + TorchscriptTests, +) + + +class PhACETests(ArchitectureTests): + architecture = "experimental.phace" + + @pytest.fixture(params=[0, 1, 2]) + def o3_lambda(self, request: pytest.FixtureRequest) -> int: + return request.param + + @pytest.fixture + def minimal_model_hypers(self): + hypers = get_default_hypers(self.architecture)["model"] + hypers = copy.deepcopy(hypers) + hypers["num_element_channels"] = 4 + return hypers + + +class TestOutput(OutputTests, PhACETests): + is_equivariant_reflections = False + + @pytest.fixture + def n_last_layer_features(self) -> int: + return 256 + + +class TestTorchscript(TorchscriptTests, PhACETests): + float_hypers = [ + "cutoff", + "cutoff_width", + "nu_scaling", + "mp_scaling", + "overall_scaling", + "radial_basis.max_eigenvalue", + ] + + def test_torchscript( + self, model_hypers: dict, dataset_info: DatasetInfo, dtype: Any + ) -> None: + model = self.model_cls(model_hypers, dataset_info) + system = System( + types=torch.tensor([6, 1, 8, 7]), + positions=torch.tensor( + [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0], [0.0, 0.0, 3.0]] + ), + cell=torch.zeros(3, 3), + pbc=torch.tensor([False, False, False]), + ) + system = get_system_with_neighbor_lists( + system, model.requested_neighbor_lists() + ) + + model.module = model.fake_gradient_model + del model.gradient_model + del model.fake_gradient_model + + model = torch.jit.script(model) + model( + [system], + model.outputs, + ) + + def test_torchscript_save_load( + self, tmpdir: Any, model_hypers: dict, dataset_info: DatasetInfo + ) -> None: + 
"""Tests that the model can be jitted, saved and loaded. + + :param tmpdir: Temporary directory where to save the + model. + :param model_hypers: Hyperparameters to initialize the model. + :param dataset_info: Dataset to initialize the model. + """ + + model = self.model_cls(model_hypers, dataset_info) + + model.module = model.fake_gradient_model + del model.gradient_model + del model.fake_gradient_model + + with tmpdir.as_cwd(): + torch.jit.save(torch.jit.script(model), "model.pt") + torch.jit.load("model.pt") + + def test_torchscript_integers( + self, model_hypers: dict, dataset_info: DatasetInfo + ) -> None: + """Tests that the model can be jitted when some float + parameters are instead supplied as integers. + + :param model_hypers: Hyperparameters to initialize the model. + :param dataset_info: Dataset to initialize the model. + """ + + new_hypers = copy.deepcopy(model_hypers) + for hyper in self.float_hypers: + nested_key = hyper.split(".") + sub_dict = new_hypers + for key in nested_key[:-1]: + sub_dict = sub_dict[key] + sub_dict[nested_key[-1]] = int(sub_dict[nested_key[-1]]) + + model = self.model_cls(new_hypers, dataset_info) + + system = System( + types=torch.tensor([6, 1, 8, 7]), + positions=torch.tensor( + [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0], [0.0, 0.0, 3.0]] + ), + cell=torch.zeros(3, 3), + pbc=torch.tensor([False, False, False]), + ) + system = get_system_with_neighbor_lists( + system, model.requested_neighbor_lists() + ) + + model.module = model.fake_gradient_model + del model.gradient_model + del model.fake_gradient_model + + model = torch.jit.script(model) + model( + [system], + model.outputs, + ) + + def test_torchscript_dtypechange( + self, model_hypers: dict, dataset_info: DatasetInfo, dtype: torch.dtype + ) -> None: + pass + + +class TestCheckpoints(CheckpointTests, PhACETests): + incompatible_trainer_checkpoints = [] + + @pytest.fixture + def model_trainer( + self, + dataset_path: str, + dataset_targets: dict, + minimal_model_hypers: dict, + default_hypers: dict, + ) -> tuple[ModelInterface, TrainerInterface]: + # Load dataset + dataset, targets_info, dataset_info = self.get_dataset( + dataset_targets, dataset_path + ) + + # Initialize model + model = self.model_cls(minimal_model_hypers, dataset_info) + + # Set the training hyperparameters: + # - Just 1 epoch to keep the test fast + # - Default loss for each target + hypers = copy.deepcopy(default_hypers) + hypers["training"]["compile"] = False + hypers["training"]["num_epochs"] = 1 + loss_hypers = OmegaConf.create( + {k: init_with_defaults(LossSpecification) for k in dataset_targets} + ) + loss_hypers = OmegaConf.to_container(loss_hypers, resolve=True) + hypers["training"]["loss"] = loss_hypers + + # Initialize trainer + trainer = self.trainer_cls(hypers["training"]) + + # Train the model. 
+ trainer.train( + model, + dtype=model.__supported_dtypes__[0], + devices=[torch.device("cpu")], + train_datasets=[dataset], + val_datasets=[dataset], + checkpoint_dir="", + ) + + return model, trainer diff --git a/src/metatrain/experimental/phace/tests/test_regression.py b/src/metatrain/experimental/phace/tests/test_regression.py new file mode 100644 index 0000000000..e6bee20da9 --- /dev/null +++ b/src/metatrain/experimental/phace/tests/test_regression.py @@ -0,0 +1,152 @@ +import random + +import numpy as np +import torch +from metatomic.torch import ModelOutput +from omegaconf import OmegaConf + +from metatrain.experimental.phace import PhACE, Trainer +from metatrain.utils.data import Dataset, DatasetInfo +from metatrain.utils.data.readers import read_systems, read_targets +from metatrain.utils.data.target_info import get_energy_target_info +from metatrain.utils.hypers import init_with_defaults +from metatrain.utils.loss import LossSpecification +from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists + +from . import DATASET_PATH, DEFAULT_HYPERS, MODEL_HYPERS + + +# reproducibility +random.seed(0) +np.random.seed(0) +torch.manual_seed(0) + + +def test_regression_init(): + """Perform a regression test on the model at initialization""" + + targets = {} + targets["mtt::U0"] = get_energy_target_info( + "energy", {"quantity": "energy", "unit": "eV"} + ) + + dataset_info = DatasetInfo( + length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=targets + ) + model = PhACE(MODEL_HYPERS, dataset_info) + + model.module = model.fake_gradient_model + del model.gradient_model + del model.fake_gradient_model + model = torch.jit.script(model) + + # Predict on the first five systems + systems = read_systems(DATASET_PATH)[:5] + systems = [system.to(torch.float32) for system in systems] + for system in systems: + get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + + output = model( + systems, + {"mtt::U0": ModelOutput(quantity="energy", unit="", per_atom=False)}, + ) + + expected_output = torch.tensor( + [ + [-0.000364058418], + [0.009637143463], + [0.006566384807], + [-0.012186427601], + [0.008798411116], + ] + ) + + # if you need to change the hardcoded values: + # torch.set_printoptions(precision=12) + # print(output["mtt::U0"].block().values) + + torch.testing.assert_close(output["mtt::U0"].block().values, expected_output) + + +def test_regression_train(): + """Perform a regression test on the model when + trained for 2 epoch on a small dataset""" + + systems = read_systems(DATASET_PATH) + + conf = { + "mtt::U0": { + "quantity": "energy", + "read_from": DATASET_PATH, + "reader": "ase", + "key": "U0", + "unit": "eV", + "type": "scalar", + "per_atom": False, + "num_subtargets": 1, + "forces": False, + "stress": False, + "virial": False, + } + } + targets, target_info_dict = read_targets(OmegaConf.create(conf)) + dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]}) + + hypers = DEFAULT_HYPERS.copy() + hypers["training"]["scheduler_patience"] = 1 + hypers["training"]["fixed_composition_weights"] = {} + loss_conf = {"energy": init_with_defaults(LossSpecification)} + loss_conf["energy"]["gradients"] = { + "positions": init_with_defaults(LossSpecification) + } + loss_conf = OmegaConf.create(loss_conf) + OmegaConf.resolve(loss_conf) + hypers["training"]["loss"] = loss_conf + hypers["training"]["num_epochs"] = 2 + hypers["training"]["compile"] = False + + dataset_info = DatasetInfo( + length_unit="Angstrom", atomic_types=[1, 6, 7, 8], 
targets=target_info_dict + ) + model = PhACE(MODEL_HYPERS, dataset_info) + + hypers["training"]["num_epochs"] = 1 + trainer = Trainer(hypers["training"]) + trainer.train( + model=model, + dtype=torch.float32, + devices=[torch.device("cpu")], + train_datasets=[dataset], + val_datasets=[dataset], + checkpoint_dir=".", + ) + + # Predict on the first five systems + systems = [system.to(torch.float32) for system in systems] + for system in systems: + get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) + + model.module = model.fake_gradient_model + del model.gradient_model + del model.fake_gradient_model + model = torch.jit.script(model) + output = model( + systems[:5], + {"mtt::U0": ModelOutput(quantity="energy", unit="", per_atom=False)}, + ) + + expected_output = torch.tensor( + [ + [0.359379351139], + [0.278062015772], + [0.233197182417], + [0.495620548725], + [0.101988784969], + ] + ) + + # if you need to change the hardcoded values: + # torch.set_printoptions(precision=12) + # print(output["mtt::U0"].block().values) + + torch.testing.assert_close(output["mtt::U0"].block().values, expected_output) diff --git a/src/metatrain/experimental/phace/trainer.py b/src/metatrain/experimental/phace/trainer.py new file mode 100644 index 0000000000..84e85c5a1a --- /dev/null +++ b/src/metatrain/experimental/phace/trainer.py @@ -0,0 +1,649 @@ +import contextlib +import copy +import logging +import math +from pathlib import Path +from typing import Any, Dict, List, Literal, Optional, Union, cast + +import torch +import torch.distributed +from metatomic.torch import ModelOutput +from torch.fx.experimental.proxy_tensor import make_fx +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import DataLoader, DistributedSampler + +from metatrain.utils.abc import ModelInterface, TrainerInterface +from metatrain.utils.additive import get_remove_additive_transform +from metatrain.utils.data import ( + CollateFn, + CombinedDataLoader, + Dataset, + get_num_workers, + unpack_batch, + validate_num_workers, +) +from metatrain.utils.distributed.slurm import DistributedEnvironment +from metatrain.utils.io import check_file_extension +from metatrain.utils.logging import ROOT_LOGGER, MetricLogger +from metatrain.utils.loss import LossAggregator, LossSpecification +from metatrain.utils.metrics import MAEAccumulator, RMSEAccumulator, get_selected_metric +from metatrain.utils.neighbor_lists import ( + get_requested_neighbor_lists, + get_system_with_neighbor_lists_transform, +) +from metatrain.utils.per_atom import average_by_num_atoms +from metatrain.utils.scaler import get_remove_scale_transform +from metatrain.utils.transfer import batch_to + +from . 
import checkpoints +from .documentation import TrainerHypers +from .model import PhACE +from .utils import InversionAugmenter, systems_to_batch + + +def _get_requested_outputs(targets, target_info_dict): + requested_outputs = {} + for name, target in targets.items(): + requested_outputs[name] = ModelOutput( + quantity=target_info_dict[name].quantity, + unit=target_info_dict[name].unit, + per_atom=target_info_dict[name].per_atom, + explicit_gradients=target.block(0).gradients_list(), + ) + return requested_outputs + + +@contextlib.contextmanager +def _disable_fx_duck_shape(): + init_duck_shape = torch.fx.experimental._config.use_duck_shape + torch.fx.experimental._config.use_duck_shape = False + try: + yield + finally: + torch.fx.experimental._config.use_duck_shape = init_duck_shape + + +def compile_model(model: PhACE, loader: torch.utils.data.DataLoader): + # inspired by the NequIP codebase + parameter_tensor = next(iter(model.parameters())) + dtype = parameter_tensor.dtype + device = parameter_tensor.device + batch = next(iter(loader)) + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype, device=device + ) + data = systems_to_batch(systems, model.requested_neighbor_lists()[0]) + with _disable_fx_duck_shape(): + fx_model = make_fx( + model.module, + tracing_mode="symbolic", + _allow_non_fake_inputs=True, + _error_on_data_dependent_ops=True, + )( + data, + [ + n + for n, o in _get_requested_outputs( + targets, model.dataset_info.targets + ).items() + if len(o.explicit_gradients) > 0 + ], + ) + compiled_module = torch.compile( + fx_model, + dynamic=True, + fullgraph=True, + mode="max-autotune", + ) + model.module = compiled_module + + +def get_scheduler( + optimizer: torch.optim.Optimizer, train_hypers: TrainerHypers, steps_per_epoch: int +) -> LambdaLR: + """ + Get a CosineAnnealing learning-rate scheduler with warmup + + :param optimizer: The optimizer for which to create the scheduler. + :param train_hypers: The training hyperparameters. + :param steps_per_epoch: The number of steps per epoch. + :return: The learning rate scheduler. 
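+
+    Typical usage, mirroring how the trainer below drives it (a minimal
+    sketch; ``train_dataloader`` stands for any dataloader)::
+
+        optimizer = torch.optim.Adam(
+            model.parameters(), lr=train_hypers["learning_rate"]
+        )
+        scheduler = get_scheduler(optimizer, train_hypers, len(train_dataloader))
+        # ...then, after every optimizer.step():
+        scheduler.step()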
+ """ + total_steps = train_hypers["num_epochs"] * steps_per_epoch + warmup_steps = int(train_hypers["warmup_fraction"] * total_steps) + min_lr_ratio = 0.0 # hardcoded for now, could be made configurable in the future + + def lr_lambda(current_step: int) -> float: + if current_step < warmup_steps: + # Linear warmup + return float(current_step) / float(max(1, warmup_steps)) + else: + # Cosine decay + progress = (current_step - warmup_steps) / float( + max(1, total_steps - warmup_steps) + ) + cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress)) + return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay + + scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda) + return scheduler + + +class Trainer(TrainerInterface[TrainerHypers]): + __checkpoint_version__ = 1 + + def __init__(self, hypers: TrainerHypers) -> None: + super().__init__(hypers) + + self.optimizer_state_dict: Optional[Dict[str, Any]] = None + self.scheduler_state_dict: Optional[Dict[str, Any]] = None + self.epoch: Optional[int] = None + self.best_epoch: Optional[int] = None + self.best_metric: Optional[float] = None + self.best_model_state_dict: Optional[Dict[str, Any]] = None + self.best_optimizer_state_dict: Optional[Dict[str, Any]] = None + + def train( + self, + model: PhACE, + dtype: torch.dtype, + devices: List[torch.device], + train_datasets: List[Union[Dataset, torch.utils.data.Subset]], + val_datasets: List[Union[Dataset, torch.utils.data.Subset]], + checkpoint_dir: str, + ) -> None: + assert dtype in PhACE.__supported_dtypes__ + + is_distributed = self.hypers["distributed"] + + if is_distributed: + if len(devices) > 1: + raise ValueError( + "Requested distributed training with the `multi-gpu` device. " + "If you want to run distributed training with PhACE, please " + "set `device` to cuda." + ) + # the calculation of the device number works both when GPUs on different + # processes are not visible to each other and when they are + distr_env = DistributedEnvironment(self.hypers["distributed_port"]) + device_number = distr_env.local_rank % torch.cuda.device_count() + device = torch.device("cuda", device_number) + torch.distributed.init_process_group(backend="nccl", device_id=device) + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + else: + rank = 0 + device = devices[0] + # only one device, as we don't support non-distributed multi-gpu for now + + if is_distributed: + logging.info(f"Training on {world_size} devices with dtype {dtype}") + else: + logging.info(f"Training on device {device} with dtype {dtype}") + + # Move the model to the device and dtype: + model.to(device=device, dtype=dtype) + # The additive models of the PhACE are always in float64 (to avoid + # numerical errors in the composition weights, which can be very large). 
+ for additive_model in model.additive_models: + additive_model.to(dtype=torch.float64) + model.scaler.to(dtype=torch.float64) + + logging.info("Calculating composition weights") + model.additive_models[0].train_model( # this is the composition model + train_datasets, + model.additive_models[1:], + self.hypers["batch_size"], + is_distributed, + self.hypers["atomic_baseline"], + ) + + if self.hypers["scale_targets"]: + logging.info("Calculating scaling weights") + model.scaler.train_model( + train_datasets, + model.additive_models, + self.hypers["batch_size"], + is_distributed, + self.hypers["fixed_scaling_weights"], + ) + + logging.info("Setting up data loaders") + + if is_distributed: + train_samplers = [ + DistributedSampler( + train_dataset, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=True, + ) + for train_dataset in train_datasets + ] + val_samplers = [ + DistributedSampler( + val_dataset, + num_replicas=world_size, + rank=rank, + shuffle=False, + drop_last=False, + ) + for val_dataset in val_datasets + ] + else: + train_samplers = [None] * len(train_datasets) + val_samplers = [None] * len(val_datasets) + + # Extract additive models and scaler and move them to CPU/float64 so they + # can be used in the collate function + model.additive_models[0].weights_to(device="cpu", dtype=torch.float64) + additive_models = copy.deepcopy( + model.additive_models.to(dtype=torch.float64, device="cpu") + ) + model.additive_models.to(device) + model.additive_models[0].weights_to(device=device, dtype=torch.float64) + model.scaler.scales_to(device="cpu", dtype=torch.float64) + scaler = copy.deepcopy(model.scaler.to(dtype=torch.float64, device="cpu")) + model.scaler.to(device) + model.scaler.scales_to(device=device, dtype=torch.float64) + + # Create collate functions: + dataset_info = model.dataset_info + train_targets = dataset_info.targets + extra_data_info = dataset_info.extra_data + inversion_augmenter = InversionAugmenter( + target_info_dict=train_targets, extra_data_info_dict=extra_data_info + ) + requested_neighbor_lists = get_requested_neighbor_lists(model) + collate_fn_train = CollateFn( + target_keys=list(train_targets.keys()), + callables=[ + inversion_augmenter.apply_random_augmentations, + get_system_with_neighbor_lists_transform(requested_neighbor_lists), + get_remove_additive_transform(additive_models, train_targets), + get_remove_scale_transform(scaler), + ], + ) + collate_fn_val = CollateFn( + target_keys=list(train_targets.keys()), + callables=[ # no augmentation for validation + get_system_with_neighbor_lists_transform(requested_neighbor_lists), + get_remove_additive_transform(additive_models, train_targets), + get_remove_scale_transform(scaler), + ], + ) + + # Create dataloader for the training datasets: + if self.hypers["num_workers"] is None: + num_workers = get_num_workers() + logging.info( + "Number of workers for data-loading not provided and chosen " + f"automatically. Using {num_workers} workers." + ) + else: + num_workers = self.hypers["num_workers"] + validate_num_workers(num_workers) + + train_dataloaders = [] + for train_dataset, train_sampler in zip( + train_datasets, train_samplers, strict=True + ): + if len(train_dataset) < self.hypers["batch_size"]: + raise ValueError( + f"A training dataset has fewer samples " + f"({len(train_dataset)}) than the batch size " + f"({self.hypers['batch_size']}). " + "Please reduce the batch size." 
+ ) + train_dataloaders.append( + DataLoader( + dataset=train_dataset, + batch_size=self.hypers["batch_size"], + sampler=train_sampler, + shuffle=( + # the sampler takes care of this (if present) + train_sampler is None + ), + drop_last=( + # the sampler takes care of this (if present) + train_sampler is None + ), + collate_fn=collate_fn_train, + num_workers=num_workers, + ) + ) + train_dataloader = CombinedDataLoader(train_dataloaders, shuffle=True) + + # Create dataloader for the validation datasets: + val_dataloaders = [] + for val_dataset, val_sampler in zip(val_datasets, val_samplers, strict=True): + if len(val_dataset) < self.hypers["batch_size"]: + raise ValueError( + f"A validation dataset has fewer samples " + f"({len(val_dataset)}) than the batch size " + f"({self.hypers['batch_size']}). " + "Please reduce the batch size." + ) + val_dataloaders.append( + DataLoader( + dataset=val_dataset, + batch_size=self.hypers["batch_size"], + sampler=val_sampler, + shuffle=False, + drop_last=False, + collate_fn=collate_fn_val, + num_workers=num_workers, + ) + ) + val_dataloader = CombinedDataLoader(val_dataloaders, shuffle=False) + + # by default, we initialize the model to use the gradient-free module; here we + # set it to use the gradient module for training so that we can compile + model.module = model.gradient_model + + if self.hypers["compile"]: + compile_model(model, train_dataloader) + + # For distributed training, we don't use DDP to avoid gradient bucketing + # which can break up the computational graph for the backward pass. + # Instead, we manually average gradients across processes. + if is_distributed: + world_size = torch.distributed.get_world_size() + + # Extract all the possible outputs and their gradients: + train_targets = model.dataset_info.targets + outputs_list = [] + for target_name, target_info in train_targets.items(): + outputs_list.append(target_name) + for gradient_name in target_info.gradients: + outputs_list.append(f"{target_name}_{gradient_name}_gradients") + + # Create a loss function: + loss_hypers = self.hypers["loss"] + loss_hypers = cast(Dict[str, LossSpecification], self.hypers["loss"]) # mypy + loss_fn = LossAggregator( + targets=train_targets, + config=loss_hypers, + ) + logging.info("Using the following loss functions:") + for name, info in loss_fn.metadata.items(): + logging.info(f"{name}:") + main = {k: v for k, v in info.items() if k != "gradients"} + logging.info(main) + if "gradients" not in info or len(info["gradients"]) == 0: + continue + logging.info("With gradients:") + for grad, ginfo in info["gradients"].items(): + logging.info(f"\t{name}::{grad}: {ginfo}") + + # Create an optimizer: + optimizer = torch.optim.Adam( + model.parameters(), lr=self.hypers["learning_rate"] + ) + if self.optimizer_state_dict is not None: + # try to load the optimizer state dict, but this is only possible + # if there are no new targets in the model (new parameters) + if not model.has_new_targets: + optimizer.load_state_dict(self.optimizer_state_dict) + + # Create a learning rate scheduler + lr_scheduler = get_scheduler(optimizer, self.hypers, len(train_dataloader)) + + if self.scheduler_state_dict is not None: + # same as the optimizer, try to load the scheduler state dict + if not model.has_new_targets: + lr_scheduler.load_state_dict(self.scheduler_state_dict) + + # per-atom targets: + per_structure_targets = self.hypers["per_structure_targets"] + + # Log the initial learning rate: + old_lr = optimizer.param_groups[0]["lr"] + logging.info(f"Initial learning rate: 
{old_lr}") + + start_epoch = 0 if self.epoch is None else self.epoch + 1 + + # Train the model: + if self.best_metric is None: + self.best_metric = float("inf") + logging.info("Starting training") + epoch = start_epoch + for epoch in range(start_epoch, start_epoch + self.hypers["num_epochs"]): + if is_distributed: + for train_sampler in train_samplers: + train_sampler.set_epoch(epoch) + + train_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + val_rmse_calculator = RMSEAccumulator(self.hypers["log_separate_blocks"]) + if self.hypers["log_mae"]: + train_mae_calculator = MAEAccumulator( + self.hypers["log_separate_blocks"] + ) + val_mae_calculator = MAEAccumulator(self.hypers["log_separate_blocks"]) + + train_loss = 0.0 + for batch in train_dataloader: + optimizer.zero_grad() + + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype, device=device + ) + + predictions = model( + systems, _get_requested_outputs(targets, dataset_info.targets) + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms(targets, systems, per_structure_targets) + + train_loss_batch = loss_fn(predictions, targets, extra_data) + + if is_distributed: + # make sure all parameters contribute to the gradient calculation + # to make sure all parameters have a gradient + for param in model.parameters(): + train_loss_batch += 0.0 * param.sum() + + train_loss_batch.backward() + + # In distributed training, manually average gradients across processes + # instead of using DDP to avoid gradient bucketing breaking the graph + if is_distributed: + for param in model.parameters(): + torch.distributed.all_reduce(param.grad) + param.grad /= world_size + + if self.hypers["gradient_clipping"] is not None: + torch.nn.utils.clip_grad_norm_( + model.parameters(), self.hypers["gradient_clipping"] + ) + optimizer.step() + lr_scheduler.step() + + if is_distributed: + # sum the loss over all processes + torch.distributed.all_reduce(train_loss_batch) + train_loss += train_loss_batch.item() + + scaled_predictions = model.scaler(systems, predictions) + scaled_targets = model.scaler(systems, targets) + train_rmse_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + if self.hypers["log_mae"]: + train_mae_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + + finalized_train_info = train_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_train_info.update( + train_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + with torch.no_grad(): + val_loss = 0.0 + for batch in val_dataloader: + systems, targets, extra_data = unpack_batch(batch) + systems, targets, extra_data = batch_to( + systems, targets, extra_data, dtype=dtype, device=device + ) + + predictions = model( + systems, _get_requested_outputs(targets, dataset_info.targets) + ) + + # average by the number of atoms + predictions = average_by_num_atoms( + predictions, systems, per_structure_targets + ) + targets = average_by_num_atoms( + targets, systems, per_structure_targets + ) + + val_loss_batch = loss_fn(predictions, targets, extra_data) + + if is_distributed: + # sum the loss over all processes + 
torch.distributed.all_reduce(val_loss_batch) + val_loss += val_loss_batch.item() + scaled_predictions = model.scaler(systems, predictions) + scaled_targets = model.scaler(systems, targets) + val_rmse_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + if self.hypers["log_mae"]: + val_mae_calculator.update( + scaled_predictions, scaled_targets, extra_data + ) + + finalized_val_info = val_rmse_calculator.finalize( + not_per_atom=["positions_gradients"] + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + if self.hypers["log_mae"]: + finalized_val_info.update( + val_mae_calculator.finalize( + not_per_atom=["positions_gradients"] + + per_structure_targets, + is_distributed=is_distributed, + device=device, + ) + ) + + # Now we log the information: + finalized_train_info = {"loss": train_loss, **finalized_train_info} + finalized_val_info = {"loss": val_loss, **finalized_val_info} + + if epoch == start_epoch: + metric_logger = MetricLogger( + log_obj=ROOT_LOGGER, + dataset_info=model.dataset_info, + initial_metrics=[finalized_train_info, finalized_val_info], + names=["training", "validation"], + ) + if epoch % self.hypers["log_interval"] == 0: + metric_logger.log( + metrics=[finalized_train_info, finalized_val_info], + epoch=epoch, + rank=rank, + learning_rate=optimizer.param_groups[0]["lr"], + ) + + val_metric = get_selected_metric( + finalized_val_info, self.hypers["best_model_metric"] + ) + if val_metric < self.best_metric: + self.best_metric = val_metric + self.best_model_state_dict = copy.deepcopy(model.state_dict()) + self.best_epoch = epoch + self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + + if epoch % self.hypers["checkpoint_interval"] == 0: + if is_distributed: + torch.distributed.barrier() + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + self.epoch = epoch + if rank == 0: + self.save_checkpoint( + model, + Path(checkpoint_dir) / f"model_{epoch}.ckpt", + ) + + # prepare for the checkpoint that will be saved outside the function + self.epoch = epoch + self.optimizer_state_dict = optimizer.state_dict() + self.scheduler_state_dict = lr_scheduler.state_dict() + + if is_distributed: + torch.distributed.destroy_process_group() + + def save_checkpoint(self, model: ModelInterface, path: Union[str, Path]) -> None: + checkpoint = model.get_checkpoint() + checkpoint.update( + { + "train_hypers": self.hypers, + "trainer_ckpt_version": self.__checkpoint_version__, + "epoch": self.epoch, + "optimizer_state_dict": self.optimizer_state_dict, + "scheduler_state_dict": self.scheduler_state_dict, + "best_epoch": self.best_epoch, + "best_metric": self.best_metric, + "best_model_state_dict": self.best_model_state_dict, + "best_optimizer_state_dict": self.best_optimizer_state_dict, + } + ) + torch.save( + checkpoint, + check_file_extension(path, ".ckpt"), + ) + + @classmethod + def load_checkpoint( + cls, + checkpoint: Dict[str, Any], + hypers: TrainerHypers, + context: Literal["restart", "finetune"], # not used at the moment + ) -> "Trainer": + trainer = cls(hypers) + trainer.optimizer_state_dict = checkpoint["optimizer_state_dict"] + trainer.scheduler_state_dict = checkpoint["scheduler_state_dict"] + trainer.epoch = checkpoint["epoch"] + trainer.best_epoch = checkpoint["best_epoch"] + trainer.best_metric = checkpoint["best_metric"] + trainer.best_model_state_dict = checkpoint["best_model_state_dict"] + trainer.best_optimizer_state_dict = 
checkpoint["best_optimizer_state_dict"] + + return trainer + + @classmethod + def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict: + for v in range(1, cls.__checkpoint_version__): + if checkpoint["trainer_ckpt_version"] == v: + update = getattr(checkpoints, f"trainer_update_v{v}_v{v + 1}") + update(checkpoint) + checkpoint["trainer_ckpt_version"] = v + 1 + + if checkpoint["trainer_ckpt_version"] != cls.__checkpoint_version__: + raise RuntimeError( + f"Unable to upgrade the checkpoint: the checkpoint is using trainer " + f"version {checkpoint['trainer_ckpt_version']}, while the current " + f"trainer version is {cls.__checkpoint_version__}." + ) + return checkpoint diff --git a/src/metatrain/experimental/phace/utils.py b/src/metatrain/experimental/phace/utils.py new file mode 100644 index 0000000000..ce04842e8e --- /dev/null +++ b/src/metatrain/experimental/phace/utils.py @@ -0,0 +1,440 @@ +import random +from typing import Dict, List, Optional, Tuple + +import metatensor.torch as mts +import torch +from metatensor.torch import TensorBlock, TensorMap +from metatomic.torch import NeighborListOptions, System, register_autograd_neighbors + +from metatrain.utils import torch_jit_script_unless_coverage +from metatrain.utils.data import TargetInfo + + +def systems_to_batch( + systems: List[System], nl_options: NeighborListOptions +) -> Dict[str, torch.Tensor]: + """ + Convert a list of System objects directly to a GNN-batch-like dictionary. + + This function creates a torch-compile-friendly batch representation with + stacked positions, cells, number of atoms per structure, atomic types, + and center/neighbor indices. + + :param systems: List of System objects to batch + :param nl_options: Neighbor list options to extract neighbor information + :return: Dictionary containing batched tensors: + - positions: stacked positions of all atoms [N_total, 3] + - cells: stacked unit cells [N_structures, 3, 3] + - species: atomic types of all atoms [N_total] + - n_atoms: number of atoms per structure [N_structures] + - cell_shifts: cell shift vectors for all pairs [N_pairs, 3] + - centers: local atom indices within each structure [N_total] + - center_indices: global center indices for all pairs [N_pairs] + - neighbor_indices: global neighbor indices for all pairs [N_pairs] + - structure_centers: structure index for each atom [N_total] + - structure_pairs: structure index for each pair [N_pairs] + - structure_offsets: cumulative atom offsets per structure [N_structures] + """ + device = systems[0].positions.device + + positions_list = [] + species_list = [] + cells_list = [] + n_atoms_list: List[int] = [] + edge_index_list = [] + cell_shifts_list = [] + centers_list = [] + structures_centers_list = [] + structure_pairs_list = [] + + cumulative_atoms = 0 + for i, system in enumerate(systems): + n_atoms_i = len(system.positions) + n_atoms_list.append(n_atoms_i) + + positions_list.append(system.positions) + species_list.append(system.types) + cells_list.append(system.cell) + + nl = system.get_neighbor_list(nl_options) + samples = nl.samples.values + edge_indices = samples[:, :2] # local center/neighbor indices + cell_shifts_item = samples[:, 2:] + + # Create global indices by adding cumulative offset + global_center_indices = edge_indices[:, 0] + cumulative_atoms + global_neighbor_indices = edge_indices[:, 1] + cumulative_atoms + + edge_index_list.append( + torch.stack([global_center_indices, global_neighbor_indices], dim=1) + ) + cell_shifts_list.append(cell_shifts_item) + + 
centers_list.append(torch.arange(n_atoms_i, device=device, dtype=torch.int32)) + structures_centers_list.append( + torch.full((n_atoms_i,), i, device=device, dtype=torch.int32) + ) + structure_pairs_list.append( + torch.full((len(edge_indices),), i, device=device, dtype=torch.int32) + ) + + cumulative_atoms += n_atoms_i + + positions = torch.cat(positions_list, dim=0) + species = torch.cat(species_list, dim=0) + cells = torch.stack(cells_list, dim=0) + n_atoms = torch.tensor(n_atoms_list, device=device, dtype=torch.int64) + pairs = torch.cat(edge_index_list, dim=0) + cell_shifts = torch.cat(cell_shifts_list, dim=0) + centers = torch.cat(centers_list, dim=0) + structure_centers = torch.cat(structures_centers_list, dim=0) + structure_pairs = torch.cat(structure_pairs_list, dim=0) + + # Compute structure offsets (cumulative sum of n_atoms, starting with 0) + structure_offsets = torch.zeros(len(systems), device=device, dtype=torch.int32) + if len(systems) > 1: + structure_offsets[1:] = torch.cumsum(n_atoms[:-1].to(torch.int32), dim=0) + + batch_dict = { + "positions": positions, + "cells": cells, + "species": species, + "n_atoms": n_atoms, + "cell_shifts": cell_shifts, + "centers": centers, + "center_indices": pairs[:, 0], + "neighbor_indices": pairs[:, 1], + "structure_centers": structure_centers, + "structure_pairs": structure_pairs, + "structure_offsets": structure_offsets, + } + + return batch_dict + + +def get_random_inversion() -> int: + """ + Randomly choose an inversion factor (-1 or 1). + + :return: either -1 or 1 + """ + return random.choice([1, -1]) + + +class InversionAugmenter: + """ + A class to apply random inversions to a set of systems and their targets. + + :param target_info_dict: A dictionary mapping target names to their corresponding + :class:`TargetInfo` objects. This is used to determine the type of targets and + how to apply the augmentations. + :param extra_data_info_dict: An optional dictionary mapping extra data names to + their corresponding :py:class:`TargetInfo` objects. This is used to determine + the type of extra data and how to apply the augmentations. + """ + + def __init__( + self, + target_info_dict: Dict[str, TargetInfo], + extra_data_info_dict: Optional[Dict[str, TargetInfo]] = None, + ): + # checks on targets + for target_info in target_info_dict.values(): + if target_info.is_cartesian: + if len(target_info.layout.block(0).components) > 2: + raise ValueError( + "InversionAugmenter only supports Cartesian targets " + "with `rank<=2`." + ) + + self.target_info_dict = target_info_dict + if extra_data_info_dict is None: + extra_data_info_dict = {} + self.extra_data_info_dict = extra_data_info_dict + + def apply_random_augmentations( + self, + systems: List[System], + targets: Dict[str, TensorMap], + extra_data: Optional[Dict[str, TensorMap]] = None, + ) -> Tuple[List[System], Dict[str, TensorMap], Dict[str, TensorMap]]: + """ + Applies random augmentations to a number of ``System`` objects, their targets, + and optionally extra data. + + :param systems: A list of :py:class:`System` objects to be augmented. + :param targets: A dictionary mapping target names to their corresponding + :py:class:`TensorMap` objects. These are the targets to be augmented. + :param extra_data: An optional dictionary mapping extra data names to their + corresponding :class:`TensorMap` objects. This extra data will also be + augmented if provided. + + :return: A tuple containing the augmented systems and targets. 
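+
+        For example (a minimal sketch; each system is independently kept as is
+        or inverted through the origin with equal probability)::
+
+            augmenter = InversionAugmenter(target_info_dict=dataset_info.targets)
+            systems, targets, extra_data = augmenter.apply_random_augmentations(
+                systems, targets
+            )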
+ """ + inversions = [get_random_inversion() for _ in range(len(systems))] + return self.apply_augmentations( + systems, targets, inversions, extra_data=extra_data + ) + + def apply_augmentations( + self, + systems: List[System], + targets: Dict[str, TensorMap], + inversions: List[int], + extra_data: Optional[Dict[str, TensorMap]] = None, + ) -> Tuple[List[System], Dict[str, TensorMap], Dict[str, TensorMap]]: + """ + Applies augmentations to a number of ``System`` objects, their targets, and + optionally extra data. The augmentations are defined by a list of rotations + and a list of inversions. + + :param systems: A list of :py:class:`System` objects to be augmented. + :param targets: A dictionary mapping target names to their corresponding + :py:class:`TensorMap` objects. These are the targets to be augmented. + :param rotations: A list of :class:`scipy.spatial.transform.Rotation` objects + representing the rotations to be applied to each system. + :param inversions: A list of integers (1 or -1) representing the + inversion factors to be applied to each system. + :param extra_data: An optional dictionary mapping extra data names to their + corresponding :class:`TensorMap` objects. This extra data will also be + augmented if provided. + + :return: A tuple containing the augmented systems and targets. + """ + self._validate(systems, inversions) + + return _apply_augmentations(systems, targets, inversions, extra_data=extra_data) + + def _validate(self, systems: List[System], inversions: List[int]) -> None: + if len(inversions) != len(systems): + raise ValueError( + "The number of inversions must match the number of systems." + ) + if any(i not in [1, -1] for i in inversions): + raise ValueError("Inversions must be either 1 or -1.") + + +def _apply_inversions_to_spherical_tensor_map( + systems: List[System], + target_tmap: TensorMap, + inversions: List[int], +) -> TensorMap: + new_blocks: List[TensorBlock] = [] + for key, block in target_tmap.items(): + ell, sigma = int(key[0]), int(key[1]) + values = block.values + if "atom" in block.samples.names: + split_values = torch.split( + values, [len(system.positions) for system in systems] + ) + else: + split_values = torch.split(values, [1 for _ in systems]) + new_values = [] + ell = (len(block.components[0]) - 1) // 2 + for v, i in zip(split_values, inversions, strict=True): + is_inverted = i == -1 + new_v = v.clone() + if is_inverted: # inversion + new_v = new_v * (-1) ** ell * sigma + new_values.append(new_v) + new_values = torch.concatenate(new_values) + new_block = TensorBlock( + values=new_values, + samples=block.samples, + components=block.components, + properties=block.properties, + ) + new_blocks.append(new_block) + + return TensorMap( + keys=target_tmap.keys, + blocks=new_blocks, + ) + + +@torch_jit_script_unless_coverage # script for speed +def _apply_augmentations( + systems: List[System], + targets: Dict[str, TensorMap], + inversions: List[int], + extra_data: Optional[Dict[str, TensorMap]] = None, +) -> Tuple[List[System], Dict[str, TensorMap], Dict[str, TensorMap]]: + # Apply the transformations to the systems + + new_systems: List[System] = [] + for system, i in zip(systems, inversions, strict=True): + new_system = System( + positions=system.positions * i, + types=system.types, + cell=system.cell * i, + pbc=system.pbc, + ) + for data_name in system.known_data(): + data = system.get_data(data_name) + # check if this data is easy to handle (scalar/vector), otherwise error out + if len(data) != 1: + raise ValueError( + f"System data 
'{data_name}' has {len(data)} blocks, which is not " + "supported by InversionAugmenter. Only scalar and vector data are " + "supported." + ) + if len(data.block().components) == 0: + # scalar data, no change + new_system.add_data(data_name, data) + elif len(data.block().components) == 1 and data.block().components[ + 0 + ].names == ["xyz"]: + # this assumes that this is a proper vector (quite safe) + new_system.add_data( + data_name, + TensorMap( + keys=data.keys, + blocks=[ + TensorBlock( + values=data.block().values * i, + samples=data.block().samples, + components=data.block().components, + properties=data.block().properties, + ) + ], + ), + ) + else: + raise ValueError( + f"System data '{data_name}' has components " + f"{data.block().components}, which are not supported by " + "InversionAugmenter. Only scalar and vector data are supported." + ) + for options in system.known_neighbor_lists(): + neighbors = mts.detach_block(system.get_neighbor_list(options)) + + neighbors.values[:] = (neighbors.values.squeeze(-1) * i).unsqueeze(-1) + + register_autograd_neighbors(system, neighbors) + new_system.add_neighbor_list(options, neighbors) + new_systems.append(new_system) + + # Apply the transformation to the targets and extra data + new_targets: Dict[str, TensorMap] = {} + new_extra_data: Dict[str, TensorMap] = {} + + # Do not transform any masks present in extra_data + if extra_data is not None: + mask_keys: List[str] = [] + for key in extra_data.keys(): + if key.endswith("_mask"): + mask_keys.append(key) + for key in mask_keys: + new_extra_data[key] = extra_data.pop(key) + + for tensormap_dict, new_dict in zip( + [targets, extra_data], [new_targets, new_extra_data], strict=True + ): + if tensormap_dict is None: + continue + assert tensormap_dict is not None + for name, original_tmap in tensormap_dict.items(): + is_scalar = False + if len(original_tmap.blocks()) == 1: + if len(original_tmap.block().components) == 0: + is_scalar = True + + is_cartesian = False + if len(original_tmap.blocks()) == 1: + if len(original_tmap.block().components) > 0: + if "xyz" in original_tmap.block().components[0].names[0]: + is_cartesian = True + + is_spherical = all( + len(block.components) == 1 and block.components[0].names == ["o3_mu"] + for block in original_tmap.blocks() + ) + + if is_scalar: + # scalars (e.g. energies) are invariant, no change + energy_block = TensorBlock( + values=original_tmap.block().values, + samples=original_tmap.block().samples, + components=original_tmap.block().components, + properties=original_tmap.block().properties, + ) + if original_tmap.block().has_gradient("positions"): + # transform position gradients: + block = original_tmap.block().gradient("positions") + position_gradients = block.values.squeeze(-1) + split_sizes_forces = [ + system.positions.shape[0] for system in systems + ] + split_position_gradients = torch.split( + position_gradients, split_sizes_forces + ) + position_gradients = torch.cat( + [ + split_position_gradients[i] * inversions[i] + for i in range(len(systems)) + ] + ) + energy_block.add_gradient( + "positions", + TensorBlock( + values=position_gradients.unsqueeze(-1), + samples=block.samples, + components=block.components, + properties=block.properties, + ), + ) + if original_tmap.block().has_gradient("strain"): + # strain gradients (rank-2 tensors) are invariant under inversion: + block = original_tmap.block().gradient("strain") + energy_block.add_gradient( + "strain", + TensorBlock( + values=block.values, + samples=block.samples, + components=block.components, +
properties=block.properties, + ), + ) + new_dict[name] = TensorMap( + keys=original_tmap.keys, + blocks=[energy_block], + ) + + elif is_spherical: + new_dict[name] = _apply_inversions_to_spherical_tensor_map( + systems, original_tmap, inversions + ) + + elif is_cartesian: + rank = len(original_tmap.block().components) + if rank == 1: + # transform Cartesian vector (assume proper, quite safe) + block = original_tmap.block() + vectors = block.values + if "atom" in original_tmap.block().samples.names: + split_vectors = torch.split( + vectors, [len(system.positions) for system in systems] + ) + else: + split_vectors = torch.split(vectors, [1 for _ in systems]) + new_vectors = [] + for v, i in zip(split_vectors, inversions, strict=True): + new_v = v * i + new_vectors.append(new_v) + new_vectors = torch.cat(new_vectors) + new_dict[name] = TensorMap( + keys=original_tmap.keys, + blocks=[ + TensorBlock( + values=new_vectors, + samples=block.samples, + components=block.components, + properties=block.properties, + ) + ], + ) + elif rank == 2: + # assume proper tensor (quite safe), unchanged + new_dict[name] = original_tmap + + return new_systems, new_targets, new_extra_data diff --git a/src/metatrain/pet/tests/test_basic.py b/src/metatrain/pet/tests/test_basic.py index 6952a559fe..6fb290a935 100644 --- a/src/metatrain/pet/tests/test_basic.py +++ b/src/metatrain/pet/tests/test_basic.py @@ -36,7 +36,8 @@ class TestInput(InputTests, PETTests): ... class TestOutput(OutputTests, PETTests): - is_equivariant_model = False + is_equivariant_rotations = False + is_equivariant_reflections = False @pytest.fixture def n_features(self, model_hypers): diff --git a/src/metatrain/soap_bpnn/trainer.py b/src/metatrain/soap_bpnn/trainer.py index 0cd6a22571..c874ca33fe 100644 --- a/src/metatrain/soap_bpnn/trainer.py +++ b/src/metatrain/soap_bpnn/trainer.py @@ -35,9 +35,7 @@ ) from metatrain.utils.per_atom import average_by_num_atoms from metatrain.utils.scaler import get_remove_scale_transform -from metatrain.utils.transfer import ( - batch_to, -) +from metatrain.utils.transfer import batch_to from . import checkpoints from .documentation import TrainerHypers @@ -108,7 +106,7 @@ def train( if len(devices) > 1: raise ValueError( "Requested distributed training with the `multi-gpu` device. " - " If you want to run distributed training with SOAP-BPNN, please " + "If you want to run distributed training with SOAP-BPNN, please " "set `device` to cuda." 
) # the calculation of the device number works both when GPUs on different diff --git a/src/metatrain/utils/testing/mtt_plugin.py b/src/metatrain/utils/testing/mtt_plugin.py index cadc3aaa7f..afe8f32066 100644 --- a/src/metatrain/utils/testing/mtt_plugin.py +++ b/src/metatrain/utils/testing/mtt_plugin.py @@ -41,7 +41,7 @@ def pytest_runtest_makereport(item: Any, call: Any) -> Generator: tb = longrepr.reprtraceback message = ( - f"\nCheckout this test's documentation to understand it: \n{doc_url}\n" + f"\nCheck out this test's documentation to understand it: \n{doc_url}\n" ) # Add our link *inside* the traceback display diff --git a/src/metatrain/utils/testing/output.py b/src/metatrain/utils/testing/output.py index 929d2e3d15..e12533a18f 100644 --- a/src/metatrain/utils/testing/output.py +++ b/src/metatrain/utils/testing/output.py @@ -48,9 +48,12 @@ class OutputTests(ArchitectureTests): """Whether the model supports returning features.""" supports_last_layer_features: bool = True """Whether the model supports returning last-layer features.""" - is_equivariant_model: bool = True + is_equivariant_rotations: bool = True """Whether the model is equivariant (i.e. produces outputs that transform correctly under rotations by architecture's design).""" + is_equivariant_reflections: bool = True + """Whether the model is equivariant (i.e. produces outputs that + transform correctly under reflections by architecture's design).""" @pytest.fixture def n_features(self) -> Optional[int | list[int]]: @@ -547,7 +550,7 @@ def test_output_last_layer_features( assert "energy" in outputs assert "mtt::aux::energy_last_layer_features" in outputs - last_layer_features = outputs["mtt::aux::energy_last_layer_features"].block() + last_layer_features = outputs["mtt::aux::energy_last_layer_features"].block(0) expected_samples = ["system", "atom"] if per_atom else ["system"] assert last_layer_features.samples.names == expected_samples assert last_layer_features.properties.names == ["feature"] @@ -606,7 +609,7 @@ def test_output_last_layer_features_selected_atoms( ) model = model.to(systems[0].positions.dtype) out = model(systems, outputs, selected_atoms=selected_atoms) - features = out[output_label].block().samples.values + features = out[output_label].block(0).samples.values assert features.shape == selected_atoms.values.shape def test_single_atom( @@ -645,14 +648,14 @@ def test_output_scalar_invariant( This test is skipped if the model does not support scalar outputs, or if the model is not equivariant by design, i.e., if either - ``supports_scalar_outputs`` or ``is_equivariant_model`` is set to + ``supports_scalar_outputs`` or ``is_equivariant_rotations`` is set to ``False``. :param model_hypers: Hyperparameters to initialize the model. :param dataset_info: Dataset information to initialize the model. :param dataset_path: Path to a dataset file to read systems from. """ - if not self.supports_scalar_outputs or not self.is_equivariant_model: + if not self.supports_scalar_outputs or not self.is_equivariant_rotations: pytest.skip( f"{self.architecture} does not produce invariant scalar outputs." ) @@ -691,7 +694,7 @@ def test_output_spherical_equivariant_rotations( This test is skipped if the model does not support spherical outputs, or if the model is not equivariant by design, i.e., if either - ``supports_spherical_outputs`` or ``is_equivariant_model`` is set to + ``supports_spherical_outputs`` or ``is_equivariant_rotations`` is set to ``False``. :param model_hypers: Hyperparameters to initialize the model. 
@@ -699,7 +702,7 @@ def test_output_spherical_equivariant_rotations( :param dataset_path: Path to a dataset file to read systems from. """ - if not self.supports_spherical_outputs or not self.is_equivariant_model: + if not self.supports_spherical_outputs or not self.is_equivariant_rotations: pytest.skip( f"{self.architecture} does not produce equivariant spherical outputs." ) @@ -754,7 +757,7 @@ def test_output_spherical_equivariant_inversion( This test is skipped if the model does not support spherical outputs, or if the model is not equivariant by design, i.e., if either - ``supports_spherical_outputs`` or ``is_equivariant_model`` is set to + ``supports_spherical_outputs`` or ``is_equivariant_reflections`` is set to ``False``. :param model_hypers: Hyperparameters to initialize the model. @@ -764,7 +767,7 @@ def test_output_spherical_equivariant_inversion( :param o3_sigma: The O(3) sigma of the spherical output to test. """ - if not self.supports_spherical_outputs or not self.is_equivariant_model: + if not self.supports_spherical_outputs or not self.is_equivariant_reflections: pytest.skip( f"{self.architecture} does not produce equivariant spherical outputs." ) diff --git a/tests/utils/test_architectures.py b/tests/utils/test_architectures.py index 8b82b0effa..5be7df775a 100644 --- a/tests/utils/test_architectures.py +++ b/tests/utils/test_architectures.py @@ -24,13 +24,14 @@ def is_None(*args, **kwargs) -> None: def test_find_all_architectures(): all_arches = find_all_architectures() - assert len(all_arches) == 8 + assert len(all_arches) == 9 assert "gap" in all_arches assert "pet" in all_arches assert "soap_bpnn" in all_arches assert "deprecated.nanopet" in all_arches assert "experimental.flashmd" in all_arches + assert "experimental.phace" in all_arches assert "experimental.classifier" in all_arches assert "llpr" in all_arches assert "experimental.mace" in all_arches diff --git a/tox.ini b/tox.ini index 7b39199e67..e167b64ff4 100644 --- a/tox.ini +++ b/tox.ini @@ -17,6 +17,7 @@ envlist = llpr-tests classifier-tests mace-tests + phace-tests [testenv] package = wheel @@ -199,6 +200,17 @@ changedir = src/metatrain/experimental/flashmd/tests/ commands = pytest {posargs} +[testenv:phace-tests] +description = Run PhACE tests with pytest +passenv = * +deps = + pytest + spherical # tensor target tests +extras = phace +changedir = src/metatrain/experimental/phace/tests/ +commands = + pytest {posargs} + [testenv:classifier-tests] description = Run Classifier tests with pytest passenv = *
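
As a usage sketch, the new InversionAugmenter can be wired into a collate step roughly as follows. The collate_with_inversions helper and the surrounding data-loading context are hypothetical illustrations; only the constructor and apply_random_augmentations calls mirror the API defined in src/metatrain/experimental/phace/utils.py above.

    # Minimal sketch, assuming `systems`/`targets` come from a metatrain
    # dataset and `target_info_dict` was built when the model was created.
    from typing import Dict, List

    from metatensor.torch import TensorMap
    from metatomic.torch import System

    from metatrain.experimental.phace.utils import InversionAugmenter
    from metatrain.utils.data import TargetInfo


    def collate_with_inversions(
        systems: List[System],
        targets: Dict[str, TensorMap],
        target_info_dict: Dict[str, TargetInfo],
    ):
        """Randomly invert each system (factor -1 with probability 1/2),
        together with its targets, before a training step."""
        augmenter = InversionAugmenter(target_info_dict)
        # Scalars are left unchanged, Cartesian vectors flip sign, and
        # spherical targets pick up a factor (-1)**ell * sigma, as
        # implemented in utils.py above.
        new_systems, new_targets, _ = augmenter.apply_random_augmentations(
            systems, targets
        )
        return new_systems, new_targets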